In [None]:
import gc, os, glob, numpy as np, matplotlib.pyplot as plt, shapely.ops, \
    scipy.ndimage as ndimage, shapely.geometry as geom, shapely, \
    matplotlib.gridspec as gridspec, imageio, joblib as jl, pandas as pd
from tqdm import tqdm_notebook as tqdm
from skimage.measure import label
import skimage.draw
import scipy.ndimage

In [None]:
# Rendering switches
PLOT_EXAMPLES = False  # True: use the hand-picked CLUSTER_EXAMPLES instead of randomly sampled events
DEBUG = False          # True: print each clip's (recording, frame) identifier on top of it
REPETITIONS = 10       # number of 6-clip pages rendered per cluster

### Load data

In [None]:
# Shared t-SNE embeddings: index 0 = posterior (midline) features,
# index 1 = anterior features (matches POSTERIOR / ANTERIOR defined below).
fnameP = ('Y:/wavelet/rawmvmt_dlc_euclidean-midline_no-abspos_no-vel_00000000010001000000010001_60_16_meansub_scalestd/'
          'rawmvmt_dlc_euclidean-midline_no-abspos_no-vel_00000000010001000000010001_60_16_meansub_scalestd_hipow_tsne_no-pca_perplexity_100_200000_2000_euclidean.npy')
# Extended-length Windows path prefix, with backslash separators
fnameP = '\\\\?\\' + fnameP.replace('/', '\\')

fnameA = ('Y:/wavelet/rawmvmt_dlc_euclidean_no-abspos_no-vel_00000010001000000010001000_60_16_meansub_scalestd/'
          'rawmvmt_dlc_euclidean_no-abspos_no-vel_00000010001000000010001000_60_16_meansub_scalestd_hipow_tsne_no-pca_perplexity_100_200000_2000_euclidean.npy')
fnameA = '\\\\?\\' + fnameA.replace('/', '\\')

fnames = [fnameP, fnameA]


def _loadEmbeddingDensity(path):
    """Load a point array, drop rows containing NaN, and build its density image.

    Returns (points, density): density is a 200x200 histogram of columns 0 and 1,
    clipped at its 99th percentile so a few dense bins do not dominate the display.
    """
    points = np.load(path).copy()
    points = points[~np.isnan(points).any(axis=1)]
    density = np.histogram2d(points[:, 0], points[:, 1], bins=(200, 200))[0]
    density = np.clip(density, 0, np.percentile(density, 99))
    return points, density


arrs, Hs = [], []
for fname in fnames:
    # A None filename keeps the list positions aligned with the views
    arr, H = (None, None) if fname is None else _loadEmbeddingDensity(fname)
    arrs.append(arr)
    Hs.append(H)

In [None]:
# Per-recording filtered embedding files (expanded into full paths by the glob below).
fnames = [
    # posterior (midline) features
    'rawmvmt_dlc_euclidean-midline_no-abspos_no-vel_00000000010001000000010001_60_16_meansub_scalestd_hipow_tsne_no-pca_perplexity_100_200000_2000_euclidean.filtered2.npy',
    # anterior features
    'rawmvmt_dlc_euclidean_no-abspos_no-vel_00000010001000000010001000_60_16_meansub_scalestd_hipow_tsne_no-pca_perplexity_100_200000_2000_euclidean.filtered2.npy',
]
fnameP, fnameA = fnames

# Indices into the (posterior, anterior) pairs used throughout this notebook
POSTERIOR, ANTERIOR = 0, 1

In [None]:
# Expand each filename pattern into the matching files under
# Z:\behavior\<recording>\wavelet\, skipping any path that contains 'RIG'.
# The lists are sorted because later cells index both views with one recording
# id (e.g. getEventIdxsForCluster mixes arrClusters[view][recid] with data
# derived from fnames[0]). glob.glob itself returns files in arbitrary order.
# NOTE(review): this cell rebinds `fnames` from its own previous value, so it
# must not be re-run on its own.
fnames = [sorted(x for x in glob.glob('\\\\?\\Z:\\behavior\\*\\wavelet\\{}'.format(fn)) if 'RIG' not in x)
          for fn in fnames]
len(fnames[0]), len(fnames[1])

In [None]:
# Shared t-SNE embedding paths again: index 0 = posterior (midline), index 1 = anterior.
fnameP = ('Y:/wavelet/rawmvmt_dlc_euclidean-midline_no-abspos_no-vel_00000000010001000000010001_60_16_meansub_scalestd/'
          'rawmvmt_dlc_euclidean-midline_no-abspos_no-vel_00000000010001000000010001_60_16_meansub_scalestd_hipow_tsne_no-pca_perplexity_100_200000_2000_euclidean.npy')
fnameP = '\\\\?\\' + fnameP.replace('/', '\\')

fnameA = ('Y:/wavelet/rawmvmt_dlc_euclidean_no-abspos_no-vel_00000010001000000010001000_60_16_meansub_scalestd/'
          'rawmvmt_dlc_euclidean_no-abspos_no-vel_00000010001000000010001000_60_16_meansub_scalestd_hipow_tsne_no-pca_perplexity_100_200000_2000_euclidean.npy')
fnameA = '\\\\?\\' + fnameA.replace('/', '\\')

fnamesShared = [fnameP, fnameA]

# Watershed segmentation stored next to each embedding. A missing file is
# represented as None in both lists instead of crashing inside np.load(None).
fnamesWatershed = [x.replace('.npy', '.smoothed.watershed.npy') for x in fnamesShared]
fnamesWatershed = [(x if os.path.exists(x) else None) for x in fnamesWatershed]
arrWatershed = [(np.load(x) if x is not None else None) for x in fnamesWatershed]

In [None]:
def toNumber(x):
    """Parse ``x`` as an int, returning -1 when it cannot be converted.

    Used to parse comma-separated cluster ids, where -1 marks tokens to discard.
    Only conversion failures (ValueError / TypeError) map to -1. A bare
    ``except`` would also hide KeyboardInterrupt and unrelated bugs.
    """
    try:
        return int(x)
    except (TypeError, ValueError):
        return -1
    
def loadLabels(fnameLabels):
    """Parse a cluster-name text file into ``{label: [cluster ids]}``.

    A line containing ':' starts a new label (the text before the first ':');
    anything after the ':' on that line is ignored. Every following non-blank
    line is split on ',' and its non-negative integer tokens become that
    label's id list. Other tokens are dropped. If several id lines follow one
    label, the last one wins. Lines before the first label are stored under ''.
    """
    def _ids(line):
        # Same rule as toNumber(): unparsable tokens count as -1 and are dropped
        parsed = []
        for token in line.split(','):
            try:
                value = int(token)
            except Exception:
                value = -1
            if value >= 0:
                parsed.append(value)
        return parsed

    with open(fnameLabels, 'r') as handle:
        text = handle.read()

    labels = {}
    current = ''
    for line in text.split('\n'):
        if ':' in line:
            current = line.split(':', 1)[0]
        elif line.strip():
            labels[current] = _ids(line)
    return labels

# Hand-written cluster-name files mapping behaviour labels to watershed cluster ids.
# Tuples are ordered (posterior, anterior), like the other per-view lists.
fnameClusterLabelsA = '\\\\?\\Y:\\wavelet\\clips\\rawmvmt_dlc_euclidean_no-abspos_no-vel_00000010001000000010001000_60_16_meansub_scalestd\\cluster_names.txt'
fnameClusterLabelsP = '\\\\?\\Y:\\wavelet\\clips\\rawmvmt_dlc_euclidean-midline_no-abspos_no-vel_00000000010001000000010001_60_16_meansub_scalestd\\cluster_names.txt'
fnamesLabels = (fnameClusterLabelsP, fnameClusterLabelsA)

clusterLabels = tuple(loadLabels(x) for x in fnamesLabels)
# Union of label names over both views, without 'noisy'. Sorted so the order
# does not depend on string-hash randomisation between kernel sessions.
clusterLabelsUnique = sorted((set(clusterLabels[0]) | set(clusterLabels[1])) - {'noisy'})

In [None]:
# arrClusters[view][recording]: column 0 of the *.clusters.npy file that sits
# next to each filtered embedding file.
arrClusters = []
for viewFnames in fnames:
    arrClusters.append([np.load(path.replace('.filtered2', '.clusters'))[:, 0]
                        for path in tqdm(viewFnames, leave=False)])

In [None]:
# Get events (legacy implementation, superseded by getEvents below)
def getEventsOld(a, cls):
    """Return, for each run of consecutive samples of ``a`` whose value is in ``cls``,
    the index at which that run reaches 12 samples (the original note calls this
    the 240 ms minimum duration). Shorter runs yield nothing.
    """
    minRun = 12
    onsets = []
    runLength = 0
    for idx in range(a.shape[0]):
        if a[idx] not in cls:
            runLength = 0
            continue
        if runLength >= minRun:
            # Already reported this run; wait for it to end
            continue
        runLength += 1
        if runLength == minRun:
            onsets.append(idx)
    return onsets

In [None]:
# Get events
def getEvents(a, cls):
    """Return the indices lying well inside runs of ``cls`` members in ``a``.

    The membership mask (``np.isin(a, cls)``) is eroded 6 times with scipy's
    default structuring element. For a 1-D ``a`` this means a run survives only
    if it is at least 13 samples long, and only its interior (all but 6 samples
    at each end) is returned. Runs touching the array edges are eroded from the
    edge side too.
    """
    inCluster = np.isin(a, cls)
    core = scipy.ndimage.binary_erosion(inCluster, iterations=6)
    return np.nonzero(core)[0]

In [None]:
# For each recording of the posterior list, the first *_img.npy file in the
# sibling ../croprot directory. NOTE(review): these are also used for the
# anterior view, presumably because both lists cover the same recordings.
fnamesCroprot = []
for fn in tqdm(fnames[0], leave=False):
    croprotPattern = os.path.abspath(os.path.join(os.path.dirname(fn), '../croprot/*_img.npy'))
    fnamesCroprot.append(glob.glob(croprotPattern)[0])

In [None]:
# Frame-index file stored alongside each croprot image stack (first glob match)
fnamesCroprotIdx = []
for croprotFname in fnamesCroprot:
    idxPattern = os.path.join(os.path.dirname(croprotFname),
                              '*_dlc_abs_filt_interp_mvmt_noborder.idx.npy')
    fnamesCroprotIdx.append(glob.glob(idxPattern)[0])

In [None]:
# Load every croprot index array. NOTE(review): later cells use these as
# boolean masks over croprot frames (np.argwhere(...)[:, 0]); confirm the dtype.
arrCroprotIdx = list(map(np.load, fnamesCroprotIdx))

In [None]:
def getClips(fname, starts, numframes, frameShape=(200, 200)):
    """Cut fixed-length clips out of a raw uint8 frame stack on disk.

    Parameters
    ----------
    fname : path of a headerless file of consecutive uint8 frames, each of
        shape ``frameShape``. Trailing bytes that do not fill a frame are ignored.
    starts : first frame index of each clip. Starts may be negative or let the
        clip run past the last frame; out-of-range frames are left black (0)
        instead of raising a shape mismatch.
    numframes : number of frames per clip.
    frameShape : shape of one frame (default 200x200, as before).

    Returns
    -------
    uint8 array of shape (len(starts), numframes, *frameShape).
    """
    frameShape = tuple(frameShape)
    frameSize = int(np.prod(frameShape))
    numFramesInFile = os.path.getsize(fname) // frameSize

    clips = np.zeros((len(starts), numframes) + frameShape, dtype=np.uint8)
    if numFramesInFile == 0:
        # np.memmap cannot map an empty region
        return clips

    frames = np.memmap(fname, mode='r', dtype=np.uint8, shape=(numFramesInFile,) + frameShape)
    try:
        for i, start in enumerate(starts):
            start = int(start)
            lo = max(start, 0)
            hi = min(start + numframes, numFramesInFile)
            if hi > lo:
                clips[i, (lo - start):(hi - start)] = frames[lo:hi]
    finally:
        # Drop our reference so the memory map can be closed
        del frames
    return clips

In [None]:
def _rotationStability(fnameMat, window=100, maxJumpDeg=60):
    """Flag frames whose crop rotation stayed steady over the preceding window.

    ``fnameMat`` is a raw float64 file with 4 values per frame. Column 3 is
    treated as an angle in degrees (NOTE(review): inferred from the 360-degree
    wrap below; confirm). For each frame, the largest frame-to-frame angle
    change (taken the short way round the circle) over the trailing ``window``
    frames is computed. The frame is stable when that maximum is below
    ``maxJumpDeg``. Frames without a full window of valid values compare as
    NaN and are therefore reported as unstable.
    """
    numFrames = os.path.getsize(fnameMat) // (4 * np.dtype(np.double).itemsize)
    mat = np.memmap(fnameMat, mode='r', dtype=np.double, shape=(numFrames, 4))
    try:
        # Copy the column so the memory map can be released right away
        angle = pd.Series(np.array(mat[:, 3]))
    finally:
        del mat

    jump = angle.diff(1).abs().values
    jump = np.minimum(np.abs(360 - jump), jump)
    # raw=True hands plain ndarrays to np.nanmax, which is much faster than
    # building a Series per window and gives the same values
    maxJump = pd.Series(jump).rolling(window=window).apply(np.nanmax, raw=True).values
    return maxJump < maxJumpDeg


# One boolean "rotation is stable" vector per croprot recording
examplesStable = [
    _rotationStability(fnameImg.replace('_img.npy', '') + '_mat.npy')
    for fnameImg in fnamesCroprot
]

### Plot cluster overview/outlines

In [None]:
def maskToPerimeter(mask):
    """Trace the outline of the True pixels of a 2-D boolean mask.

    Each True pixel (r, c) becomes the unit square [r, r+1] x [c, c+1]. The
    squares are merged with shapely and the exterior ring is returned as an
    (N, 2) float array of (row, col) vertices. The outermost row and column on
    every side are ignored. The caller's mask is not modified.

    Returns None when no pixels remain. When the pixels form several pieces
    (the union is a MultiPolygon), only the piece with the largest area is traced.
    """
    # Work on a copy: callers should not see their border pixels cleared
    mask = np.array(mask, dtype=bool, copy=True)
    mask[0, :] = False
    mask[-1, :] = False
    mask[:, 0] = False
    mask[:, -1] = False

    pixels = np.argwhere(mask)
    if pixels.shape[0] == 0:
        return None

    squares = [geom.Polygon([[r + dr, c + dc] for dr, dc in [[0, 0], [0, 1], [1, 1], [1, 0]]])
               for r, c in pixels]
    # unary_union replaces the deprecated cascaded_union (removed in shapely 2.0)
    merged = shapely.ops.unary_union(squares)
    if merged.geom_type == 'MultiPolygon':
        merged = max(merged.geoms, key=lambda poly: poly.area)

    xs, ys = merged.exterior.coords.xy
    return np.column_stack((np.asarray(xs), np.asarray(ys)))

In [None]:
import colorcet
import matplotlib.colors
import matplotlib.cm as cm

class customColormap(matplotlib.colors.LinearSegmentedColormap):
    """Colorcet ``CET_L17`` colormap with a white lead-in for the lowest values.

    Values >= ``LIM`` are rescaled from [LIM, 1] onto the full base colormap.
    Values below ``LIM`` blend linearly from white (at 0) to the base
    colormap's first colour (at LIM), in both the float and the
    ``bytes=True`` (uint8) output conventions of the base colormap.

    NOTE(review): ``__init__`` does not call the parent constructor, so
    Colormap attributes other than ``N`` are not initialised.
    """
    LIM = 0.10

    def __init__(self, *args, **kwargs):
        self.baseCM = cm.get_cmap('cet_CET_L17')
        self.N = self.baseCM.N

    def _mapColor(self, x, *args, **kwargs):
        """Map one scalar ``x`` to an RGBA tuple."""
        lim = self.LIM
        if x >= lim:
            return self.baseCM((x - lim) / (1.0 - lim), *args, **kwargs)
        _c = self.baseCM(0.0, *args, **kwargs)
        # Fraction of the way from white (x = 0) to the base colour (x = lim).
        # Both branches use x / lim; the float branch used x / 0.05, which
        # overshot past the base colour (and out of [0, 1]) for 0.05 <= x < lim.
        z = x / lim
        if isinstance(_c[0], float):
            return (_c[0] * z + 1.0 * (1 - z),
                    _c[1] * z + 1.0 * (1 - z),
                    _c[2] * z + 1.0 * (1 - z),
                    1.0)
        return (int(_c[0] * z + 255.0 * (1 - z)),
                int(_c[1] * z + 255.0 * (1 - z)),
                int(_c[2] * z + 255.0 * (1 - z)),
                255)

    def __call__(self, r, *args, **kwargs):
        """Map a scalar or an array of any dimensionality to RGBA.

        Scalars return a single RGBA tuple. Arrays return shape
        ``r.shape + (4,)``, float64 or uint8 matching the base colormap output.
        """
        values = np.asarray(r)
        if values.ndim == 0:
            return self._mapColor(values.item(), *args, **kwargs)
        _c = self.baseCM(0.0, *args, **kwargs)
        dtype = np.float64 if isinstance(_c[0], float) else np.uint8
        colors = np.array([self._mapColor(v, *args, **kwargs) for v in values.ravel()], dtype=dtype)
        return colors.reshape(values.shape + (4,))

In [None]:
# One row per cluster label: (cluster key, plot colour, display name).
# NOTE(review): 'right-leg' is displayed as 'Left Leg' and 'left-leg' as
# 'Right Leg'. This is presumably deliberate because the clips are shown in
# ventral view; confirm.
_CLUSTER_STYLE = [
    ('right-leg',            '#a7dcfa', 'Left Leg'),
    ('left-leg',             '#f0c566', 'Right Leg'),
    ('one-leg-after-other',  '#7a988c', 'Alternating Legs'),
    ('both-legs',            '#a0718b', 'Both Legs (Rotation)'),
    ('walk',                 '#66a3d3', 'Walk'),
    ('bend-abdomen',         '#e69e66', 'Anchor'),
    ('stabilimentum',        '#e0afca', 'Stabilimentum'),
    ('extrude',              '#66c5ab', 'Silk pull (Fast)'),
    ('extrude-slow',         '#f6ef8e', 'Silk pull (Slow)'),
    ('stationary',           '#999999', 'Stationary'),
    ('stationary-anterior',  '#c2c2c2', 'Stationary Anterior'),
    ('stationary-posterior', '#c2c2c2', 'Stationary Posterior'),
]

CLUSTER_COLORS = {key: color for key, color, _ in _CLUSTER_STYLE}
CLUSTER_LABELS = {key: text for key, _, text in _CLUSTER_STYLE}

In [None]:
def plotReference(ax, anterior, clusterHighlight=None):
    """Draw one t-SNE density map with every labelled cluster outlined.

    Parameters
    ----------
    ax : matplotlib Axes to draw into.
    anterior : bool; selects the anterior (True) or posterior (False) entry of
        the notebook globals Hs, clusterLabels and arrWatershed.
    clusterHighlight : optional label name; that cluster is additionally filled
        with its translucent colour and drawn on top.

    Labels named 'noisy' and labels with no cluster ids are skipped. Colours
    and legend text come from CLUSTER_COLORS / CLUSTER_LABELS.
    """
    view = int(anterior)
    lo, hi = 25, 175  # displayed window of the 200x200 embedding grid

    # Density background
    ax.imshow(Hs[view], cmap='gray')

    for clusterIDsKey, clusterIDs in clusterLabels[view].items():
        if clusterIDsKey == 'noisy' or len(clusterIDs) == 0:
            continue
        # Union of all watershed regions assigned to this label
        mask = np.isin(arrWatershed[view][:, :, 0, 1], clusterIDs)

        # Morphological opening (erode then dilate once) removes thin bridges and
        # specks. Then split into connected components; label() marks background 0.
        opened = ndimage.binary_dilation(ndimage.binary_erosion(mask, iterations=1), iterations=1)
        maskLabeled = label(opened, background=0)

        color = CLUSTER_COLORS[clusterIDsKey]
        isLabeled = False  # only the first outline gets a legend entry
        for maskID in np.unique(maskLabeled):
            if maskID == 0:
                continue
            component = (maskLabeled == maskID)
            component[:lo, :] = False
            component[hi:, :] = False
            component[:, :lo] = False
            component[:, hi:] = False
            pts = maskToPerimeter(component)
            if pts is None:
                continue
            legendLabel = None if isLabeled else CLUSTER_LABELS[clusterIDsKey]
            if clusterHighlight == clusterIDsKey:
                ax.fill(pts[:, 1], pts[:, 0],
                        edgecolor=color,
                        facecolor=color + '66',
                        label=legendLabel,
                        zorder=100)
            else:
                ax.fill(pts[:, 1], pts[:, 0], facecolor='#00000000',
                        edgecolor=color,
                        label=legendLabel)
            ax.plot(pts[:, 1], pts[:, 0], linewidth=2, color=color)
            isLabeled = True

    # Legend (white text on the black figure), view window and title
    leg = ax.legend(frameon=False, fontsize=12, bbox_to_anchor=(0, 0, 1.9, 1))
    for text in leg.get_texts():
        plt.setp(text, color='white')
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.set_axis_off()
    ax.set_title(['Posterior t-SNE Embedding', 'Anterior t-SNE Embedding'][view], color='white')

### Manually-selected exemplars

In [None]:
# Hand-picked example events, returned by getEventIdxsForCluster instead of
# random samples when PLOT_EXAMPLES is True.
#   key   : (view, cluster label), where view is 'anterior' or 'posterior'
#   value : list of (recording id, frame) pairs. renderMovieSegment loads the
#           50-frame clip starting at frame - 25 from fnamesCroprot[recording id].
# renderMovieSegment shows 6 clips per repetition, so entries beyond the first
# 6 * repetitions are not shown.
CLUSTER_EXAMPLES = {
    ('anterior', 'left-leg'): [ 
        (3, 647681),
        (7, 786043),
        (13, 463875),
        (8, 476809),
        (6, 464874),
        (3, 502257)],
    ('anterior', 'right-leg'): [
        (9, 153590),
        (11, 1094661),
        (16, 422461),
        (5, 348443),
        (10, 780815),
        (12, 542387)],
    ('anterior', 'both-legs'): [
        (7, 768929),
        (9, 643407),
        (10, 670106),
        (11, 915742),
        (13, 657480),
        (15, 508988)],
    ('anterior', 'one-leg-after-other'): [
        (4, 826754),
        (7, 823370),
        (3, 471456),
        (11, 1077682),
        (12, 861423),
        (17, 564971)],
    ('anterior', 'walk'): [
        (11, 1037796),
        (17, 277974),
        (20, 311542),
        (19, 567496),
        (1, 600553),
        (3, 1968494)],
    ('anterior', 'stabilimentum'): [
        (10, 1167589),
        (12, 1188057),
        (17, 800233),
        (18, 1117439),
        (19, 1237819),
        (2, 1656160)],
    ('anterior', 'bend-abdomen'): [
        (10, 804870),
        (12, 885371),
        (14, 1055637),
        (17, 551514),
        (19, 1109645),
        (7, 862980)],
    ('anterior', 'stationary'): [
        (6, 259682),
        (5, 528479),
        (8, 356464),
        (9, 674068),
        (10, 318526),
        (11, 1017133)],
    ('anterior', 'stationary-anterior'): [
        (3, 2198400),
        (5, 623095),
        (6, 676863),
        (19, 1080384),
        (20, 470372),
        (13, 420339)],
    ('posterior', 'stabilimentum'): [
        (0, 1086445),
        (2, 1660838),
        (3, 2271448),
        (4, 1130162),
        (6, 825717),
        (12, 1186401)],
    ('posterior', 'extrude'): [
        (2, 1492653),
        (4, 923627),
        (5, 535218),
        (6, 695989),
        (7, 824777),
        (8, 384677)],
    ('posterior', 'bend-abdomen'): [
        (7, 850778),
        (9, 483474),
        (10, 858603),
        (18, 1066742),
        (20, 465958),
        (17, 435601)],
    # The only list with 7 entries; with repetitions == 1 the 7th is not shown
    ('posterior', 'extrude-slow'): [
        (18, 699113),
        (11, 664659),
        (14, 65626),
        (17, 436892),
        (0, 788133),
        (16, 1031830),
        (5, 507437)],
    ('posterior', 'stationary'): [
        (5, 685199),
        (4, 300144),
        (9, 250322),
        (10, 245729),
        (11, 294414),
        (12, 47400)],
    ('posterior', 'stationary-posterior'): [
        (2, 1357350),
        (6, 556391),
        (12, 809888),
        (15, 816500),
        (2, 1317614),
        (7, 741093)]
}

### Plot

In [None]:
# Obtain clips for a given cluster
def _allocateEvents(numTotal, available, rng):
    """Deal ``numTotal`` events round-robin over recordings in random order, capped by ``available``."""
    available = np.asarray(available, dtype=int)
    quota = np.zeros(len(available), dtype=int)
    order = rng.permutation(len(available))
    remaining = numTotal
    while remaining > 0 and np.any(quota < available):
        for r in order:
            if remaining == 0:
                break
            if quota[r] < available[r]:
                quota[r] += 1
                remaining -= 1
    return quota


def getEventIdxsForCluster(anterior, cluster, repetitions, rng=None):
    """Pick 6 * ``repetitions`` (recording id, frame) events for one cluster.

    Parameters
    ----------
    anterior : bool, or str where 'anterior' selects the anterior view and any
        other string the posterior view.
    cluster : label name (a key of clusterLabels[view]).
    repetitions : number of 6-clip pages the events are meant for.
    rng : optional object providing ``permutation`` and ``choice`` (e.g. a
        ``np.random.Generator``) for reproducible sampling. Defaults to the
        global ``np.random`` state.

    Returns
    -------
    List of (recording id, frame index) tuples. Frame indices are positions in
    the croprot frame numbering. When PLOT_EXAMPLES is set, the hand-picked
    CLUSTER_EXAMPLES list is returned unchanged.

    Events are spread as evenly as possible over recordings; the recordings
    that get one extra event are chosen at random. A recording with too few
    qualifying frames contributes all it has and the shortfall goes to the
    others. If all recordings together have fewer than 6 * ``repetitions``
    qualifying frames, fewer events are returned (previously
    np.random.choice raised ValueError).
    """
    if isinstance(anterior, str):
        anterior = (anterior == 'anterior')

    if PLOT_EXAMPLES:
        return CLUSTER_EXAMPLES[('anterior' if anterior else 'posterior', cluster)]

    if rng is None:
        rng = np.random
    view = int(anterior)
    numTotal = 6 * repetitions
    nrec = len(arrClusters[view])

    # Candidate frames per recording, in croprot frame numbering
    candidates = []
    for recid in tqdm(np.arange(nrec), leave=False, desc='obtaining event IDs'):
        _cl = arrClusters[view][recid].copy()
        # -1 is never a valid cluster id, so unstable frames cannot start events
        _cl[~examplesStable[recid][arrCroprotIdx[recid]]] = -1
        eventIdxs = getEvents(_cl, clusterLabels[view][cluster])
        candidates.append(np.argwhere(arrCroprotIdx[recid])[:, 0][eventIdxs])

    quota = _allocateEvents(numTotal, [len(c) for c in candidates], rng)

    eventIdxsAllrec = []
    for recid, (cand, n) in enumerate(zip(candidates, quota)):
        if n == 0:
            continue
        for x in rng.choice(cand, replace=False, size=n):
            eventIdxsAllrec.append((recid, x))
    return eventIdxsAllrec

In [None]:
# Quick check: one page (6 events) of the anterior 'walk' cluster
eventIdxs = getEventIdxsForCluster(anterior='anterior', cluster='walk', repetitions=1)

In [None]:
# Preview: load the 50-frame clip starting 25 frames before each sampled event
recid = 0
clipsPerEvent = []
for recIndex, frame in eventIdxs:
    clipsPerEvent.append(getClips(fnamesCroprot[recIndex], [frame - 25], 50))
clips = np.concatenate(clipsPerEvent)

In [None]:
# Circular vignette for the 200x200 clips: True where the pixel centre lies
# strictly within 80 px of (100, 100). This is the same disk that
# skimage.draw.circle(100, 100, 80) produced, but built with NumPy because
# `np.bool` was removed in NumPy 1.24 and `skimage.draw.circle` in
# scikit-image 0.19.
_rows, _cols = np.ogrid[:200, :200]
CIRCLE_MASK = (_rows - 100) ** 2 + (_cols - 100) ** 2 < 80 ** 2
rr, cc = np.nonzero(CIRCLE_MASK)

In [None]:
def _figureToRGB(fig):
    """Rasterise ``fig`` into a (height, width, 3) uint8 array, then close it."""
    fig.canvas.draw()
    if hasattr(fig.canvas, 'tostring_rgb'):
        # frombuffer replaces the deprecated binary mode of np.fromstring
        data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    else:
        # Newer matplotlib dropped tostring_rgb; use the RGBA buffer instead
        data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)
    gc.collect()
    return data


def renderMovieSegment(eventIdxs, repetitions=4, highlightCluster=None):
    """Render animation frames of 6 example clips per page next to both embeddings.

    Parameters
    ----------
    eventIdxs : sequence of (recording id, frame index). Each event yields the
        50-frame clip starting 25 frames before that index.
    repetitions : number of 6-clip pages to render. Capped at the number of
        complete pages available in ``eventIdxs`` (hand-picked example lists
        hold only one or so).
    highlightCluster : optional (view, label) tuple with view 'anterior' or
        'posterior'. That cluster is filled in its embedding panel and its
        name is printed on every clip. With None, no cluster name is printed.

    Returns
    -------
    List of RGB uint8 frames, page by page, one per clip frame.
    """
    if len(eventIdxs) == 0:
        return []

    clips = np.concatenate([getClips(fnamesCroprot[recid], [fr - 25], 50) for recid, fr in eventIdxs])
    # Blank everything outside the circular field of view once, up front
    clips[:, :, ~CIRCLE_MASK] = 0
    numFrames = clips.shape[1]
    repetitions = min(repetitions, clips.shape[0] // 6)

    highlightView = highlightCluster[0] if highlightCluster is not None else None
    highlightLabel = highlightCluster[1] if highlightCluster is not None else None

    movie = []
    for repetition in tqdm(range(repetitions), leave=False, desc='repetitions'):
        clipsoffset = repetition * 6
        for frameid in tqdm(range(numFrames), leave=False, desc='frameid'):
            fig = plt.figure(constrained_layout=True, figsize=(16, 7), facecolor='#000000')
            spec = gridspec.GridSpec(ncols=4, nrows=2, figure=fig, width_ratios=(1.3, 1, 1, 1))

            # Left column: anterior (top) and posterior (bottom) reference embeddings
            plotReference(fig.add_subplot(spec[0, 0]), anterior=True,
                          clusterHighlight=highlightLabel if highlightView == 'anterior' else None)
            plotReference(fig.add_subplot(spec[1, 0]), anterior=False,
                          clusterHighlight=highlightLabel if highlightView == 'posterior' else None)

            # Right: 3x2 grid of clips, each occupying a 200x200 tile
            ax = fig.add_subplot(spec[:, 1:])
            for sampleid in range(6):
                icol = sampleid % 3
                irow = sampleid // 3
                x0, y0 = 200 * icol, 200 * irow
                ax.imshow(clips[clipsoffset + sampleid, frameid, :, :],
                          extent=(x0, x0 + 200, y0, y0 + 200),
                          aspect=True, cmap='gray')
                # Labels
                ax.text(x0 + 10, y0 + 180, 'Clip {}'.format(
                    (1 - irow) * 3 + icol + 1 + clipsoffset), color='white')
                ax.text(x0 + 190, y0 + 180, '(ventral view)', color='white', ha='right', alpha=0.5)
                if highlightLabel is not None:
                    ax.text(x0 + 10, y0 + 166, '({})'.format(CLUSTER_LABELS[highlightLabel]),
                            color=CLUSTER_COLORS[highlightLabel])
                # Time axis: a 40 px bar for the whole clip with a red cursor.
                # The +/-0.5 s labels assume the 50-frame clip spans 1 s.
                yAxis = y0 + 20
                ax.plot([x0 + 160, x0 + 200], [yAxis, yAxis], color='white', zorder=100)
                ax.plot([x0 + 160, x0 + 160], [yAxis - 3, yAxis + 3], color='white', zorder=100)
                ax.plot([x0 + 200, x0 + 200], [yAxis - 3, yAxis + 3], color='white', zorder=100)
                ax.text(x0 + 145, y0 + 7, '–0.5s', zorder=100, color='white')
                ax.text(x0 + 185, y0 + 7, '+0.5s', zorder=100, color='white')
                progress = frameid / numFrames
                cursorX = progress * 200 + (1 - progress) * 160 + x0
                ax.plot([cursorX, cursorX], [yAxis - 6, yAxis + 6],
                        color='red', zorder=200, linewidth=3)
                # Debug information, if requested
                if DEBUG:
                    ax.text(x0 + 10, y0 + 152, '{}'.format(
                        eventIdxs[clipsoffset + sampleid]), color='white')

            ax.set_xlim(0, 200 * 3)
            ax.set_ylim(0, 200 * 2)
            ax.set_axis_off()

            movie.append(_figureToRGB(fig))
    return movie

In [None]:
# Every labelled cluster of both views (anterior first), excluding 'noisy'
clustersToHighlight = [
    (viewName, clusterName)
    for viewName, viewIndex in (('anterior', ANTERIOR), ('posterior', POSTERIOR))
    for clusterName in clusterLabels[viewIndex]
    if clusterName not in ['noisy',]
]

In [None]:
# Render one movie segment per highlighted cluster in parallel. Event sampling
# (getEventIdxsForCluster) runs in this process while the generator is
# consumed; rendering runs in the workers. The job count is capped at the CPU
# count and kept >= 1, because joblib rejects n_jobs=0 (empty cluster list).
# NOTE(review): every rendered frame is returned to this process, so memory
# grows with clusters x REPETITIONS x 50 frames.
numJobs = max(1, min(len(clustersToHighlight), os.cpu_count() or 1))
movies = jl.Parallel(n_jobs=numJobs)(
    jl.delayed(renderMovieSegment)(
        getEventIdxsForCluster(highlightCl[0], highlightCl[1], repetitions=REPETITIONS),
        repetitions=REPETITIONS,
        highlightCluster=highlightCl)
    for highlightCl in tqdm(clustersToHighlight, leave=False))

In [None]:
# Write all segments into one 25 fps MP4. The last frame of every 50-frame
# clip page is written PAUSE_FRAMES times, which holds it as a pause between pages.
# The output directory can be overridden with CLUSTER_ANIMATION_DIR.
OUTPUT_DIR = os.environ.get('CLUSTER_ANIMATION_DIR', 'C:/Users/acorver/Desktop')
FRAMES_PER_PAGE = 50
PAUSE_FRAMES = 10
fnameMovie = '{}/cluster_animations{}{}_{}reps.mp4'.format(
    OUTPUT_DIR, '_debug' if DEBUG else '', '_examples' if PLOT_EXAMPLES else '', REPETITIONS)

# The context manager closes (finalises) the file even if writing fails
with imageio.get_writer(fnameMovie, fps=25) as wr:
    for movie in movies:
        for frid, frame in enumerate(movie):
            isLastOfPage = (frid + 1) % FRAMES_PER_PAGE == 0
            for _ in range(PAUSE_FRAMES if isLastOfPage else 1):
                wr.append_data(frame)