### Imports

In [None]:
import copy
import gc
import glob
import json
import os
import warnings

import colorcet
import imageio
import joblib as jl
import matplotlib.colors as mplcol
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from simplification.cutil import simplify_coords_vw, simplify_coords_vw_idx
from tqdm import tqdm_notebook as tqdm
warnings.filterwarnings('ignore')

### Specify the recording to use

In [None]:
# We use an arbitrary recording as illustration (alternative: '4-13-19-a')
REC = '4-5-19-a'
# Root directory holding one sub-directory per recording
BEHAVIOR_DIR = 'Z:\\behavior'

# Tracked position/orientation file for this recording. Sorting makes the choice
# deterministic if several files match; an empty match fails with a clear message
# instead of a bare IndexError.
posCandidates = sorted(glob.glob(os.path.join(BEHAVIOR_DIR, REC, 'croprot', '*_dlc_position_orientation.npy')))
if len(posCandidates) == 0:
    raise FileNotFoundError('No *_dlc_position_orientation.npy found for recording {} under {}'.format(REC, BEHAVIOR_DIR))
fnamePos = posCandidates[0]
fnamePos

In [None]:
# Manual stage annotation: recording.json one directory above the tracking file's
# directory, resolved to an absolute, normalised path
fnameInfo = os.path.abspath(os.path.join(os.path.dirname(fnamePos), os.pardir, 'recording.json'))
fnameInfo

#### Load data

In [None]:
# recording.json lives in the parent of the directory containing the tracking file
fnameRecordingInfo = os.path.join(os.path.split(os.path.split(fnamePos)[0])[0], 'recording.json')

def loadJSON(x):
    """Load and return the JSON document at path `x`, or None if no file exists there.

    The file is decoded as UTF-8 (tolerating a leading byte-order mark) instead of the
    platform default encoding, which on Windows is usually a legacy code page that
    mis-decodes non-ASCII text.
    """
    if not os.path.exists(x):
        return None
    with open(x, 'r', encoding='utf-8-sig') as f:
        return json.load(f)
    
recordingInfo = loadJSON(fnameRecordingInfo)
# loadJSON returns None for a missing file; fail here with a clear message rather than
# with a TypeError when the stage annotations are first accessed
if recordingInfo is None:
    raise FileNotFoundError('Recording annotation not found: {}'.format(fnameRecordingInfo))

In [None]:
# Fill in missing stage information, if necessary.
# `s` aliases `recordingInfo`, so later cells read the converted ranges from
# recordingInfo['stages']. The raw annotation is stashed on the first run and restored
# on every run, so re-executing this cell does not re-map already-converted indices.
s = recordingInfo
s['fname'] = fnamePos
if '_stagesRaw' not in s:
    s['_stagesRaw'] = copy.deepcopy(s['stages'])
s['stages'] = copy.deepcopy(s['_stagesRaw'])
stageDict = s['stages']

# Does this recording.json file specify stage ranges, or starting points?
if isinstance(stageDict['protoweb'], list):
    # Ranges: normalise every stage to a list of [start, end] pairs
    for st, val in stageDict.items():
        if len(val) == 0:
            stageDict[st] = []
        elif not isinstance(val[0], list):
            stageDict[st] = [val]
else:
    # Starting points only: each stage ends where the next one starts, and the last
    # stage ends at the end of the recording
    nFrames = np.load(s['fname'], mmap_mode='r').shape[0]
    # A missing or negative stabilimentum start is treated as absent: it starts at the
    # end of the recording, giving an empty range
    if stageDict.get('stabilimentum', -1) < 0:
        stageDict['stabilimentum'] = nFrames
    stageOrder = ['protoweb', 'radii', 'spiral_aux', 'spiral_cap', 'stabilimentum']
    stageBounds = [stageDict[st] for st in stageOrder] + [nFrames]
    for k, st in enumerate(stageOrder):
        stageDict[st] = [[stageBounds[k], stageBounds[k + 1]]]

# Convert frame numbers into positions within the masked arrays used in the analysis:
# each boundary maps to the nearest kept frame. arrIdx is presumably a boolean mask
# (np.argwhere of it yields the kept frame numbers) — confirm.
arrIdx = np.load(fnamePos.replace('_position_orientation.npy', '_abs_filt_interp_mvmt_noborder.idx.npy'))
keptFrames = np.argwhere(arrIdx).T[0]  # computed once rather than once per boundary
for stageRanges in stageDict.values():
    for stageRange in stageRanges:
        for m in range(2):
            stageRange[m] = np.argmin(np.abs(keptFrames - stageRange[m]))

In [None]:
# Load original data
arr = np.load(fnamePos)
# Subset by the saved index mask
arrIdx = np.load(fnamePos.replace('_position_orientation.npy', '_abs_filt_interp_mvmt_noborder.idx.npy'))
arr = arr[arrIdx, :]
# The web center must be annotated manually in recording.json
if 'center' in recordingInfo:
    center = np.array(recordingInfo['center'])
else:
    raise Exception('Center of web should be manually specified')

# Convert position to polar coordinates relative to the annotated center.
# NOTE: arctan2 receives the x offset as its first argument, so the angle is measured
# from the +y axis rather than the conventional +x axis.
r = np.linalg.norm(arr[:, 0:2] - center[np.newaxis, :], axis=1)
a = np.arctan2(arr[:, 0] - center[np.newaxis, 0], arr[:, 1] - center[np.newaxis, 1])
arrPolar = np.hstack((r[:, np.newaxis], a[:, np.newaxis], arr[:, 2, np.newaxis]))


def _polarVelocity(polar, lag=25):
    """Change of each column of `polar` over `lag` samples, with the two angle columns
    (1 and 2) wrapped into [-pi, pi] by adding/subtracting 2*pi at most three times.

    np.roll wraps around, so the last `lag` rows are differenced against the first rows.
    """
    vel = np.roll(polar, -lag, axis=0) - polar
    for k in (1, 2):
        for _ in range(3):
            vel[vel[:, k] < -np.pi, k] += 2 * np.pi
        for _ in range(3):
            vel[vel[:, k] > np.pi, k] -= 2 * np.pi
    return vel


def _fillNaN(x):
    """Forward-fill, then back-fill, NaNs down each column of a 2-D array."""
    # DataFrame.ffill/bfill replace fillna(method=...), which is deprecated in pandas
    return pd.DataFrame(x).ffill().bfill().values


# Remove noise: samples more than 50 px away from the next sample
isNoise = np.linalg.norm(arr[:, 0:2] - np.roll(arr, -1, axis=0)[:, 0:2], axis=1) > 50
arr[isNoise, :] = np.nan
arrPolar[isNoise, :] = np.nan
arrPolarVel = _polarVelocity(arrPolar)
# Filter out samples with large radial AND angular velocity.
# NOTE(review): column 1 has just been wrapped into [-pi, pi], so |dtheta| > pi can never
# hold and this filter removes nothing as written — the intended threshold needs confirming.
isNoise |= (np.abs(arrPolarVel[:, 0]) > 300) & (np.abs(arrPolarVel[:, 1]) > np.pi)
arr[isNoise, :] = np.nan
arrPolar[isNoise, :] = np.nan
arrPolarVel = _polarVelocity(arrPolar)
# Arclength velocity: angular change scaled by the current radius
arrPolarVelArclen = arrPolarVel.copy()
arrPolarVelArclen[:, 1] *= arrPolar[:, 0]
# Fill the NaNs introduced above from neighbouring samples
arr = _fillNaN(arr)
arrPolar = _fillNaN(arrPolar)
arrPolarVel = _fillNaN(arrPolarVel)
arrPolarVelArclen = _fillNaN(arrPolarVelArclen)

### Set color scheme

In [None]:
# Okabe–Ito colour-blind-safe palette as hex strings; the commented-out entries are
# palette members that are not used here
COLORS = [mplcol.rgb2hex(np.array(rgb) / 255.0) for rgb in (
    (0, 0, 0),          # black
    (230, 159, 0),      # orange
    # (86, 180, 233),   # sky blue
    (0, 158, 115),      # bluish green
    # (240, 228, 66),   # yellow
    (0, 114, 178),      # blue
    # (213, 94, 0),     # vermillion
    (204, 121, 167),    # reddish purple
)]

In [None]:
# Colour of each web-building stage (values from the Okabe–Ito palette)
COLORS_STAGES = dict(
    protoweb='#e69f00',
    radii='#009E73',
    spiral_aux='#56B4E9',
    spiral_cap='#CC79A7',
    stabilimentum='#0072B2',
)

In [None]:
# Source: https://stackoverflow.com/questions/37765197/darken-or-lighten-a-color-in-matplotlib
def lighten_color(color, amount=0.5):
    """
    Rescale a colour's lightness by multiplying (1 - lightness) by `amount`.

    amount < 1 lightens the colour, amount == 1 leaves it unchanged and amount > 1
    darkens it. Input can be a matplotlib colour string, hex string, or RGB tuple;
    the result is an RGB numpy array clipped to [0, 1].

    Examples:
    >> lighten_color('g', 0.3)
    >> lighten_color('#F034A3', 0.6)
    >> lighten_color((.3,.55,.1), 0.5)
    """
    import matplotlib.colors as mc
    import colorsys
    try:
        # Resolve entries of matplotlib's named-colour table to their hex value
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # Not a table name (hex string, short code, or an RGB sequence): use it as-is
        c = color
    c = colorsys.rgb_to_hls(*mc.to_rgb(c))
    # Large amounts can push lightness below 0, hence the clip
    return np.clip(colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2]), 0, 1)

### Plot animation

In [None]:
# Downsample the trajectory with Visvalingam–Whyatt simplification, keeping the indices
# of the retained samples so they can be mapped back onto the full-resolution arrays
resVW = 20  # simplification tolerance passed to simplify_coords_vw_idx
xy = np.array(arr[:, :2], order='C')
xySimplIdx = simplify_coords_vw_idx(xy, resVW)
xySimpl = xy[xySimplIdx]

In [None]:
# Get stages: map each stage range (positions in the masked arrays) onto positions in the
# simplified trajectory by nearest retained sample. A comprehension is used so the loop
# variables do not overwrite the module-level radius array `r` or stage name `st`.
stages = {
    stageName: [(np.argmin(np.abs(xySimplIdx - bounds[0])), np.argmin(np.abs(xySimplIdx - bounds[1])))
                for bounds in stageRanges]
    for stageName, stageRanges in recordingInfo['stages'].items()
}
stages

In [None]:
# y-axis limits of the three time-series panels: r (cm), θ (deg), ω (deg)
LIMITS = [[0, 5]] + [[-180, 180] for _ in range(2)]

# Pixel-to-centimetre conversion: 10 cm per 1024 px
PIXEL_TO_CM = 10.0 / 1024

In [None]:
# Quick check: an amount above 1 darkens the colour
lighten_color(color='#aaaaaa', amount=2)

In [None]:
# Display names of the construction stages, in the left-to-right order of the
# side-by-side stage panel
STAGE_LABELS = {
    'protoweb': 'Protoweb',
    'radii': 'Radii',
    'spiral_aux': 'Auxiliary Spiral',
    'spiral_cap': 'Capture Spiral',
    'stabilimentum': 'Stabilimentum'
}
# x-offset (px) of each stage's copy of the 1024-px arena in the side-by-side panel
STAGE_OFFSETS = {stage: 1024 * k for k, stage in enumerate(STAGE_LABELS)}
# Per-repeat darkening of a stage's colour; stages not listed keep their base colour
DARKEN_FACTOR = {'radii': 0.4, 'spiral_cap': 0.9}
# Stages whose segments stay fully opaque in the web panel
ALWAYS_OPAQUE = ('radii', 'spiral_cap', 'stabilimentum')
# Time axes are labelled in minutes and divide sample indices by this
# (implies 50 samples per second — confirm against the recording setup)
SAMPLES_PER_MIN = 50 * 60
# The stage range preceding the current one fades out over this many samples (10 min)
FADE_SAMPLES = SAMPLES_PER_MIN * 10


def _stageColor(stage, numRepeats):
    """Colour for `stage`, darkened according to its number of previous occurrences."""
    if numRepeats > 0:
        # lighten_color with an amount > 1 darkens the colour
        return lighten_color(COLORS_STAGES[stage], 1 + DARKEN_FACTOR.get(stage, 0.0) * numRepeats)
    return COLORS_STAGES[stage]


def _iterSegments(xySimpl, segmentIDs, visible, lim, stages, pastStages):
    """Yield (stage, firstIdx, points, numRepeats) for every path segment starting before `lim`.

    `points` are the segment's rows of `xySimpl` restricted to `visible` (boolean indexing
    returns a copy, so callers may modify it). `pastStages` is extended in place with the
    stages encountered, consecutive duplicates collapsed; `numRepeats` is how many earlier
    entries of `pastStages` share this segment's stage. Segments whose first point lies
    in no stage range are skipped.
    """
    for segmentID in np.unique(segmentIDs[:lim]):
        inSegment = segmentIDs == segmentID
        firstIdx = np.argwhere(inSegment)[0, 0]
        st = ''
        # Every range is scanned; with overlapping ranges the last match wins
        for stage, ranges in stages.items():
            for r in ranges:
                if r[0] <= firstIdx < r[1]:
                    st = stage
                    if not pastStages or pastStages[-1] != stage:
                        pastStages.append(stage)
        if not st:
            continue
        yield st, firstIdx, xySimpl[inSegment & visible, :], pastStages.count(st) - 1


def _figureToRGB(fig):
    """Draw `fig` and return its pixels as a (height, width, 3) uint8 array."""
    fig.canvas.draw()
    # buffer_rgba() replaces tostring_rgb() (deprecated since Matplotlib 3.8) and
    # np.fromstring (deprecated); its shape comes from the renderer, so no reshape is needed
    return np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()


def plotFrame(t, xySimpl, xySimplIdx, stages, arrPolar):
    """Render one animation frame showing web construction up to time fraction `t`.

    Parameters
    ----------
    t : float
        Fraction of the recording to draw, relative to the last simplified index
        `xySimplIdx[-1]`.
    xySimpl : ndarray
        Simplified (x, y) trajectory, one row per entry of `xySimplIdx`.
    xySimplIdx : ndarray
        Indices of the simplified points into `arrPolar` (assumed increasing — the
        `<=` comparisons rely on it; TODO confirm).
    stages : dict
        Stage name -> list of (start, end) positions into `xySimplIdx`.
    arrPolar : ndarray
        Per-sample (radius in px, angle in rad, orientation in rad).

    Returns
    -------
    ndarray or None
        (height, width, 3) uint8 RGB image of the frame, or None when `t` falls inside
        no stage range (no figure is created in that case).
    """
    tIdx = t * xySimplIdx[-1]
    visible = xySimplIdx <= tIdx

    # Split the simplified path into segments. A new segment starts where the step from
    # the previous point exceeds 100 px; each stage start gets +100 added to its step so
    # it also starts a new segment (unless that step is exactly zero).
    steps = np.hstack((0, np.linalg.norm(np.diff(xySimpl, axis=0), axis=1)))
    stageStarts = np.zeros(steps.size, dtype=int)
    for ranges in stages.values():
        for r in ranges:
            steps[r[0]] += 100
            stageStarts[r[0]] += 1
    segmentIDs = np.cumsum(steps > 100)
    # Number of stage ranges started at or before each simplified point
    stageSegmentIDs = np.cumsum(stageStarts)

    # Last simplified point reached by time t
    lim = np.argwhere(visible)[-1, 0]

    # Stage range in progress at time t; nothing is drawn if there is none
    current = [(st, xySimplIdx[k[0]], stageSegmentIDs[k[0]]) for st in stages for k in stages[st]
               if xySimplIdx[k[0]] <= tIdx and xySimplIdx[k[1]] > tIdx]
    if len(current) == 0:
        return None
    currentStage, currentStageStart, currentStageSegID = current[0]

    fig = plt.figure(figsize=(15, 21))
    try:
        spec = gridspec.GridSpec(ncols=1, nrows=5, wspace=0, hspace=0.1,
                                 figure=fig, height_ratios=[6, 1.2, 0.6, 0.6, 0.6])

        # Panel 1: the web built so far, drawn in place
        ax = fig.add_subplot(spec[0, 0])
        lastPos = None
        for st, firstIdx, ps, numRepeats in _iterSegments(xySimpl, segmentIDs, visible, lim, stages, []):
            lastPos = ps[-1, :]
            if st == currentStage or st in ALWAYS_OPAQUE:
                alpha = 1.0
            elif currentStageSegID - 1 == stageSegmentIDs[firstIdx]:
                # Segment of the stage range just before the current one: fade it out
                alpha = 1.0 - min(1.0, (tIdx - currentStageStart) / FADE_SAMPLES)
            else:
                alpha = 0.0
            ax.plot(ps[:, 0], ps[:, 1], color=_stageColor(st, numRepeats), linewidth=1, alpha=alpha)
        ax.set_aspect(1.0)
        ax.set_axis_off()
        ax.set_xlim(100, 924)
        ax.set_ylim(100, 924)
        if lastPos is not None:
            ax.scatter(lastPos[0], lastPos[1], s=250, color='red', marker='*', zorder=100)

        # Panel 2: each stage drawn in its own side-by-side copy of the arena
        ax = fig.add_subplot(spec[1, 0])
        lastPos = None
        pastStages = []
        for st, _, ps, numRepeats in _iterSegments(xySimpl, segmentIDs, visible, lim, stages, pastStages):
            ps[:, 0] += STAGE_OFFSETS[st]
            lastPos = ps[-1, :]
            ax.plot(ps[:, 0], ps[:, 1], color=_stageColor(st, numRepeats), linewidth=1)
        ax.set_aspect(0.9)
        ax.set_axis_off()
        ax.set_xlim(0, 5 * 1024)
        ax.set_ylim(0, 1024)
        # Labels use STAGE_OFFSETS so they line up with the trajectories regardless of the
        # key order of `stages`
        for stage, label in STAGE_LABELS.items():
            if stage in pastStages:
                ax.text(STAGE_OFFSETS[stage] + 512, 1024, label, color='#111111', ha='center', va='center')
        if lastPos is not None:
            ax.scatter(lastPos[0], lastPos[1], s=100, color='red', marker='*')

        # Panels 3-5: r, θ and ω against time, with stage shading and a "now" marker
        minutes = xySimplIdx / SAMPLES_PER_MIN
        panels = [
            # (values, line width, y label, y ticks, y tick labels)
            (arrPolar[xySimplIdx, 0] * PIXEL_TO_CM, 0.5, 'r (cm)', [0, 5], ['    0', '5']),
            (arrPolar[xySimplIdx, 1] * 180 / np.pi, 1, 'θ (deg)', [-180, 0, 180], ['-180', '  0', '180']),
            (arrPolar[xySimplIdx, 2] * 180 / np.pi, 1, 'ω (deg)', [-180, 0, 180], ['-180', '  0', '180']),
        ]
        for iax, (values, lw, ylabel, yticks, yticklabels) in enumerate(panels):
            isBottom = iax == len(panels) - 1
            if isBottom:
                ax = fig.add_subplot(spec[2 + iax, 0])
            else:
                ax = fig.add_subplot(spec[2 + iax, 0], clip_on=False)
            ax.plot(minutes, values, color='black', linewidth=lw)
            if isBottom:
                hiddenSpines = ['top', 'right']
                ax.set_xlabel('Time, Without Pauses (minutes)')
            else:
                # Only the bottom panel keeps an x axis
                ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
                hiddenSpines = ['top', 'bottom', 'right']
            for side in hiddenSpines:
                ax.spines[side].set_visible(False)
            ax.set_xlim(0, xySimplIdx[-1] / SAMPLES_PER_MIN)
            ax.set_ylim(LIMITS[iax][0], LIMITS[iax][1])
            ax.set_ylabel(ylabel)
            ax.set_yticks(yticks)
            ax.set_yticklabels(yticklabels)

            # Shade each stage range up to time t (axhspan's xmin/xmax are axis fractions,
            # matching the x range that starts at 0)
            for stage, ranges in stages.items():
                for numRepeats, r in enumerate(ranges):
                    t0 = min(t, xySimplIdx[r[0]] / xySimplIdx[-1])
                    t1 = min(t, xySimplIdx[r[1]] / xySimplIdx[-1])
                    if t1 > t0:
                        ax.axhspan(LIMITS[iax][0], LIMITS[iax][1], t0, t1,
                                   facecolor=_stageColor(stage, numRepeats), zorder=-10, alpha=0.5)
            ax.axvline(tIdx / SAMPLES_PER_MIN, color='red')

        fig.tight_layout()
        return _figureToRGB(fig)
    finally:
        # Always release the figure: this runs once per frame inside parallel workers
        plt.close(fig)
        gc.collect()

In [None]:
# Preview a single frame (t = 0.85) before rendering the whole animation
previewFig, previewAx = plt.subplots(figsize=(14, 20))
previewAx.imshow(plotFrame(0.85, xySimpl, xySimplIdx, stages, arrPolar))

In [None]:
# Render settings
N_FRAMES = 1400        # time points rendered (t from 0 to just below 1)
N_HOLD_FRAMES = 30     # extra copies of the last frame so the video lingers at the end
CROP_MARGIN = 16       # white margin (px) kept around the content when cropping
N_JOBS = 50            # parallel rendering workers
OUTPUT_PATH = 'C:/Users/acorver/Desktop/ani_polarplot_{}.mp4'.format(REC)

frames = jl.Parallel(n_jobs=N_JOBS)(jl.delayed(plotFrame)(
    min(1.0, t), xySimpl, xySimplIdx, stages, arrPolar) for t in tqdm(np.linspace(0, 0.99999, N_FRAMES)))

# plotFrame returns None for times outside every stage range; drop those
frames = [x for x in frames if x is not None]
if len(frames) == 0:
    raise ValueError('No frames were rendered; check the stage ranges in `stages`')
frames = np.array(frames + [frames[-1]] * N_HOLD_FRAMES)

# Crop the all-white border shared by every frame. A row/column is kept if any frame has
# a non-white pixel in it. The margin is clamped at 0 so it never becomes a negative
# (wrap-around) slice index, and an entirely white stack is left uncropped.
hasContent = frames.min(axis=(0, 3)) < 255
rowsWithContent = np.flatnonzero(hasContent.any(axis=1))
colsWithContent = np.flatnonzero(hasContent.any(axis=0))
if rowsWithContent.size > 0:
    yCropTop, yCropBot = rowsWithContent[0], rowsWithContent[-1]
    xCropLeft, xCropRight = colsWithContent[0], colsWithContent[-1]
    frames = frames[:,
                    max(0, yCropTop - CROP_MARGIN):yCropBot + CROP_MARGIN,
                    max(0, xCropLeft - CROP_MARGIN):xCropRight + CROP_MARGIN, :]

# Write the video; the context manager closes the file even if a write fails
with imageio.get_writer(OUTPUT_PATH, fps=20) as wr:
    for frame in frames:
        wr.append_data(frame)