#### Imports

In [None]:
import os, glob, gc, regex as re, numpy as np, pandas as pd, simplification.cutil as simpl, \
    matplotlib.pyplot as plt, joblib as jl, json, time, matplotlib.gridspec as gridspec, \
    matplotlib.colors as mplcol, colorcet, miniball, matplotlib.ticker as ticker
from tqdm import tqdm_notebook as tqdm

# Source: https://stackoverflow.com/questions/2158395/flatten-an-irregular-list-of-lists
def flatten(*n):
    """Lazily yield the leaves of arbitrarily nested tuples/lists, left to right.

    Every positional argument that is a tuple or list is descended into
    recursively; any other object (including strings) is yielded unchanged.
    """
    for item in n:
        if isinstance(item, (tuple, list)):
            yield from flatten(*item)
        else:
            yield item

# Show the process ID of this kernel as the cell output
os.getpid()

#### Set color scheme for plotting

In [None]:
# Wong colorblind-safe palette, stored as matplotlib hex strings.
# The palette is written down as 0-255 RGB triplets and rescaled to 0-1 for rgb2hex.
# Source: https://www.nature.com/articles/nmeth.1618
# Source: https://davidmathlogic.com/colorblind/#%23648FFF-%23785EF0-%23DC267F-%23FE6100-%23FFB000
COLORS = [
    mplcol.rgb2hex(np.array(rgb) / 255.0)
    for rgb in (
        (0, 0, 0),
        (230, 159, 0),
        (0, 158, 115),
        (0, 114, 178),
        (204, 121, 167),
        (86, 180, 233),
        (240, 228, 66),
        (213, 94, 0),
    )
]

#### Load HHMM data

In [None]:
# Locate the pickled result files (polarpos data and the accompanying HHMM predictions).
# Alternative pattern for re-saved results:
#   'Y:/wavelet/hhmm-results/*regimes_12minrun_manuallabels_5fold.*.resave.pickle'
fnames = [
    path
    for path in glob.glob('Y:/wavelet/hhmm-results/*regimes_12minrun_manuallabels_5fold.pickle')
    if 'idxmapping' not in path  # index-mapping side files are loaded separately
]
# Keep only the sixth match.
# NOTE(review): glob does not guarantee an ordering, so which file index 5 refers to
# may differ between machines/filesystems -- confirm this selects the intended file.
fnames = [fnames[5]]
fnames

In [None]:
# Wall-clock timestamp taken before loading starts
t0 = time.time()

# Run a full garbage-collection pass before the files are loaded
gc.collect()

def loadPickle(fn):
    """Load one file with joblib, returning None if loading fails.

    The garbage collector is switched off while the file is being loaded
    (presumably to speed up unpickling of large objects -- confirm) and is
    always put back into the state it had on entry, on both the success and
    the failure path.

    Parameters
    ----------
    fn : str
        Path of the file to load.

    Returns
    -------
    object or None
        Whatever ``jl.load`` returns, or None when any Exception was raised;
        in that case the filename and the error message are printed.
    """
    gcWasEnabled = gc.isenabled()
    gc.disable()
    try:
        # NOTE(security): joblib.load unpickles the file, which can execute arbitrary
        # code. Only use this on files from a trusted source.
        return jl.load(fn)
    except Exception as e:
        # Deliberate best-effort: report the failure and let the caller skip this file
        print(fn, str(e))
        return None
    finally:
        # Runs after either return above, so the collector is never left disabled
        if gcWasEnabled:
            gc.enable()
    
# Load every result file (with a progress bar) and keep only the ones that loaded;
# loadPickle signals a failed load by returning None.
data = [loaded for loaded in map(loadPickle, tqdm(fnames)) if loaded is not None]

# Wall-clock timestamp taken after loading has finished
t1 = time.time()

#### Load position data

In [None]:
def loadPositionData(modelIdx, modelRepIdx, modelRecIdx):
    """Load the position/orientation array and the recording metadata of one recording.

    The position file is searched for as ``../croprot/*dlc_position_orientation.npy``
    relative to the directory of the first filename stored in
    ``data[modelIdx]['fnames'][modelRecIdx]``. ``recording.json`` is read from the
    directory one level above the position file's directory. The index array is read
    from the sibling file ``*_abs_filt_interp_mvmt_noborder.idx.npy``.

    Parameters
    ----------
    modelIdx : int
        Index into the module-level ``data`` list.
    modelRepIdx : int
        Not used here; kept so the signature matches ``computeData`` and ``plot``.
    modelRecIdx : int
        Index of the recording within ``data[modelIdx]['fnames']``.

    Returns
    -------
    arr : numpy.ndarray
        Rows of the position file selected by the index array.
    recordingInfo : dict
        Contents of ``recording.json`` with two changes: the key ``'fname'`` holds the
        position file path, and every entry of ``'stages'`` is a list of two-element
        lists whose values were replaced by the position, among the frames selected by
        the index array, of the selected frame closest to the original value.

    Raises
    ------
    FileNotFoundError
        If no position file matches the pattern or ``recording.json`` does not exist.
    """
    # Locate the position/orientation file
    pattern = os.path.abspath(os.path.join(os.path.dirname(
        data[modelIdx]['fnames'][modelRecIdx][0]), '../croprot/*dlc_position_orientation.npy'))
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError('No position/orientation file matches {}'.format(pattern))
    fnamePos = matches[0]

    # Load recording info
    fnameRecInfo = os.path.join(os.path.dirname(os.path.dirname(fnamePos)), 'recording.json')
    if not os.path.exists(fnameRecInfo):
        raise FileNotFoundError('Missing recording info file: {}'.format(fnameRecInfo))
    with open(fnameRecInfo, 'r') as f:
        recordingInfo = json.load(f)
    recordingInfo['fname'] = fnamePos

    # A stage may be given as a single [a, b] pair or as a list of such pairs;
    # normalise to a list of pairs. Empty stages are left untouched.
    stages = recordingInfo['stages']
    for st in stages:
        if stages[st] != [] and not isinstance(stages[st][0], list):
            stages[st] = [stages[st], ]

    # Index array selecting the frames used in the analysis (loaded once, reused below)
    arrIdx = np.load(fnamePos.replace('_position_orientation.npy',
                                      '_abs_filt_interp_mvmt_noborder.idx.npy'))

    # Convert stage boundaries to indices used in analysis: each value is replaced by
    # the position of the closest selected frame. The frame list does not change
    # inside the loops, so it is computed once.
    keptFrames = np.argwhere(arrIdx).T[0]
    for st in stages:
        for frameRange in stages[st]:
            for m in range(2):
                frameRange[m] = np.argmin(np.abs(keptFrames - frameRange[m]))

    # Load original data and subset by index
    arr = np.load(fnamePos)
    arr = arr[arrIdx,:]

    # Done
    return arr, recordingInfo

#### Collect HHMM regime probabilities and compute polar position data

In [None]:
def computeData(modelIdx, modelRepIdx, modelRecIdx):
    """Collect per-frame regime probabilities and polar position data for one recording.

    Parameters
    ----------
    modelIdx : int
        Index into the module-level ``data`` / ``fnames`` lists.
    modelRepIdx : int
        Index into ``data[modelIdx]['models']``.
    modelRecIdx : int
        Index of the recording.

    Returns
    -------
    arr : numpy.ndarray
        Position array from ``loadPositionData``; rows flagged as noise are set to NaN.
    arrPolar : numpy.ndarray
        Three columns: distance from the web center, angle around the center (radians,
        from ``arctan2``), and the third column of ``arr`` unchanged. Noise rows are NaN.
    probRegimes : numpy.ndarray
        One column per regime: the largest state probability within that regime,
        expanded through the index mapping and forward/backward filled.
    probRegimesSmoothedStacked : numpy.ndarray
        ``numRegimes + 1`` columns: cumulative sums (starting at 0) of the regime
        probabilities after a 3000-sample rolling maximum and row normalisation.
    recordingInfo : dict
        Recording metadata as returned by ``loadPositionData``.

    Raises
    ------
    ValueError
        If the recording metadata has no ``'center'`` entry.
    """
    # Load position data
    arr, recordingInfo = loadPositionData(modelIdx, modelRepIdx, modelRecIdx)

    # Load index mapping (stored next to the model file as '<name>.idxmapping.pickle')
    arrIdxMapping = jl.load(fnames[modelIdx].replace(re.search('([.0-9]*\\.pickle)$',
        fnames[modelIdx]).group(0), '.idxmapping.pickle').replace('.resave', '').replace('.1', ''))

    # Load raw regime (log-)probabilities; if no stored predictions can be read,
    # recompute them from the fitted model.
    model = data[modelIdx]['models'][modelRepIdx]
    try:
        d = model['statesPredProb'][modelRecIdx].copy()
    except Exception:
        d = model['model'].predict_log_proba(model['model-fit-states'][modelRecIdx]).copy()

    # Regime ID of every model state, parsed from state names of the form 'r<ID>_...';
    # states whose name does not match get -1.
    regimeIDs = np.array([(int(m.group(1)) if m is not None else -1) for m in
        (re.search('^r([0-9]*)_', state.name) for state in model['model'].states)])

    # Per regime, take the largest log-probability over its states, then exponentiate
    probRegimes = np.zeros((d.shape[0], data[modelIdx]['numRegimesHHMM']))
    for regimeID in range(probRegimes.shape[1]):
        probRegimes[:, regimeID] = np.nanmax(d[:, np.argwhere(regimeIDs == regimeID)[:,0]], axis=1)
    probRegimes = np.exp(probRegimes)

    # Reshape array to re-introduce duplicate states using the index loaded above;
    # negative mapping entries become NaN rows, which are then filled from neighbours.
    probRegimes = np.array([(probRegimes[i,:] if i >= 0 else np.full(
        probRegimes.shape[1], np.nan)) for i in arrIdxMapping[modelRecIdx]])
    probRegimes = pd.DataFrame(probRegimes).ffill().bfill().values

    # Smooth (rolling maximum) and normalise each row to sum to 1
    probRegimesSmoothed = pd.DataFrame(probRegimes).rolling(window=3000).max().values
    probRegimesSmoothed /= np.sum(probRegimesSmoothed, axis=1)[:,np.newaxis]

    # Stack: column j+1 is the running sum of the first j+1 regimes, column 0 stays 0
    probRegimesSmoothedStacked = np.zeros((probRegimesSmoothed.shape[0], probRegimesSmoothed.shape[1] + 1))
    probRegimesSmoothedStacked[:, 1:] = np.cumsum(probRegimesSmoothed, axis=1)

    # Determine center
    if 'center' in recordingInfo:
        center = np.array(recordingInfo['center'])
    else:
        print(recordingInfo)
        raise ValueError('Center of web should be manually specified')

    # Convert position to polar coordinates relative to approximate center
    r = np.linalg.norm(arr[:,0:2] - center[np.newaxis,:], axis=1)
    a = np.arctan2(arr[:,0] - center[np.newaxis,0], arr[:,1] - center[np.newaxis,1])
    arrPolar = np.hstack((r[:,np.newaxis], a[:,np.newaxis], arr[:,2,np.newaxis]))

    # Remove noise: frames whose position jumps by more than 50 to the next frame.
    # np.roll wraps around, so the last frame is compared with the first one.
    isNoise = np.linalg.norm(arr[:,0:2] - np.roll(arr, -1, axis=0)[:,0:2], axis=1) > 50
    arr[isNoise,:] = np.nan
    arrPolar[isNoise,:] = np.nan

    # Compute velocities as the change over the next 25 frames
    arrPolarVel = np.roll(arrPolar, -25, axis=0) - arrPolar
    # Wrap rotations into -pi / +pi
    for k in [1,2]:
        for i in range(3):
            arrPolarVel[arrPolarVel[:, k] < -np.pi, k] += 2 * np.pi
        for i in range(3):
            arrPolarVel[arrPolarVel[:, k] >  np.pi, k] -= 2 * np.pi

    # Filter out frames with implausible velocities.
    # NOTE(review): the angular velocity was just wrapped into [-pi, pi], so the second
    # condition cannot be true and this filter never flags anything. If '|' was
    # intended, changing it will alter the results -- confirm before editing.
    isNoise |= (np.abs(arrPolarVel[:,0]) > 300) & (np.abs(arrPolarVel[:,1]) > np.pi)
    arr[isNoise,:] = np.nan
    arrPolar[isNoise,:] = np.nan

    # Done
    return arr, arrPolar, probRegimes, probRegimesSmoothedStacked, recordingInfo

In [None]:
# Source: https://stackoverflow.com/questions/37765197/darken-or-lighten-a-color-in-matplotlib
def lighten_color(color, amount=0.5):
    """
    Lightens the given color by multiplying (1-luminosity) by the given amount.
    Input can be matplotlib color string, hex string, or RGB tuple.

    Parameters:
        color:  color specification accepted by matplotlib.colors.to_rgb
        amount: factor applied to (1-luminosity); 1 leaves the lightness unchanged,
                smaller values give a lighter color

    Returns:
        numpy array with the three RGB components, each clipped to [0, 1]

    Examples:
    >> lighten_color('g', 0.3)
    >> lighten_color('#F034A3', 0.6)
    >> lighten_color((.3,.55,.1), 0.5)
    """
    import matplotlib.colors as mc
    import colorsys
    try:
        # Named colors are translated to their hex value first
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # Not a known name (KeyError) or not hashable, e.g. a list/array (TypeError):
        # pass the value through and let to_rgb interpret it
        c = color
    c = colorsys.rgb_to_hls(*mc.to_rgb(c))
    return np.clip(colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2]), 0, 1)

#### Plot spider position and regime probabilities

In [None]:
# Inspect the 'fnames' entry at index 4 of the first loaded result file
data[0]['fnames'][4]

In [None]:
def plot(modelIdx, modelRepIdx, modelRecIdx, resVW_w = 50, resVW_hhmm = 0.001, regimeColors = None, regimeOrder = None,
         outDir = 'C:/Users/acorver/Desktop/paper-figures'):
    """Plot position, regime probabilities and per-regime trajectories of one recording.

    The figure has two time-series panels (distance from the hub, angular position)
    with the recording stages shaded, one stacked regime-probability panel, and one
    trajectory panel per regime. It is saved as a PDF in ``outDir``.

    Parameters
    ----------
    modelIdx : int
        Index into the module-level ``data`` list.
    modelRepIdx : int
        Index into ``data[modelIdx]['models']``.
    modelRecIdx : int or str
        Recording index, or a substring of the recording's first filename; the first
        recording whose filename contains the substring is used.
    resVW_w : float
        Tolerance passed to ``simplify_coords_vw`` for the trajectory panels.
    resVW_hhmm : float
        Tolerance passed to ``simplify_coords_vw_idx`` for the probability curves.
    regimeColors : sequence, optional
        One matplotlib color per regime. Defaults to ``COLORS[1:]``, cycling if needed.
    regimeOrder : sequence of int, optional
        Panel slot of each regime; slot ``s`` (capped at 5) is drawn in row ``s % 2``,
        column ``s // 2`` of the trajectory grid. Defaults to ``0, 1, 2, ...``.
    outDir : str
        Directory the PDF is written to.

    Raises
    ------
    ValueError
        If ``modelRecIdx`` is a string that matches no recording filename.
    """
    if isinstance(modelRecIdx, str):
        # Look the recording up in the same model file that is being plotted
        matches = [xi for xi, x in enumerate(data[modelIdx]['fnames']) if modelRecIdx in x[0]]
        if not matches:
            raise ValueError('No recording filename contains {!r}'.format(modelRecIdx))
        modelRecIdx = matches[0]
        print('Model rec index: {}'.format(modelRecIdx))

    arr, arrPolar, probRegimes, probRegimesSmoothedStacked, recordingInfo = \
        computeData(modelIdx, modelRepIdx, modelRecIdx)

    numRegimes = probRegimes.shape[1]
    if regimeColors is None:
        regimeColors = [COLORS[1 + (i % (len(COLORS) - 1))] for i in range(numRegimes)]
    if regimeOrder is None:
        regimeOrder = list(range(numRegimes))

    fig = plt.figure(figsize=(24, 6))
    spec = gridspec.GridSpec(ncols=4, nrows=4, width_ratios=[2.5, 0.5, 0.5, 0.5])

    PIXEL_TO_CM = 10.0 / 1024        # presumably 10 cm across 1024 pixels -- confirm
    LIMITS = [[0, 6], [-180, 180]]   # y-limits: distance (cm), angle (degrees)
    FRAMES_PER_MINUTE = 60 * 50      # presumably 50 frames per second -- confirm

    # Time axis in minutes and the largest stage boundary; both are shared by all panels
    t = np.arange(arrPolar.shape[0]) / FRAMES_PER_MINUTE
    imax = np.max(list(flatten([st for st in recordingInfo['stages'].values()])))

    # Draw polar pos
    for pp in range(2):
        ax1 = fig.add_subplot(spec[pp, 0])

        # Distance is converted to cm, angle to degrees; rows with NaN are dropped
        if pp == 0:
            values = arrPolar[:, pp] * PIXEL_TO_CM
        else:
            values = arrPolar[:, pp] * 180 / np.pi
        coords = np.column_stack((t, values))
        coords = coords[~np.any(np.isnan(coords), axis=1)]

        # Simplify timeseries before plotting
        # -- Keep doubling the resVW factor until the number of points is reduced to 1% of original
        #    (or fewer than 10000 points remain; checked first so an empty result cannot divide by zero)
        resVW = 0.000001
        while True:
            coordsSimpl = simpl.simplify_coords_vw(coords, resVW)
            if coordsSimpl.shape[0] < 10000 or coords.shape[0] / coordsSimpl.shape[0] > 100:
                break
            resVW *= 2
        coords = coordsSimpl

        ax1.plot(coords[:,0], coords[:,1], color='#444444', linewidth=0.5)
        ax1.set_ylim(LIMITS[pp][0], LIMITS[pp][1])

        if pp == 0:
            ax1.set_ylabel('Distance from Hub\n(cm)')
        else:
            ax1.set_ylabel('Angular Position\n(Degrees)')
            ax1.set_yticks([-180, -90, 0, 90, 180])
            ax1.yaxis.set_major_formatter(ticker.FormatStrFormatter("%d°"))

        ax1.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)

        # Draw stage transitions
        for istage, stage in enumerate(recordingInfo['stages']):
            # Shade the frame ranges of this stage; axhspan takes the horizontal extent
            # as a fraction of the axis, hence the division by imax
            for i0, i1 in recordingInfo['stages'][stage]:
                ax1.axhspan(LIMITS[pp][0], LIMITS[pp][1],
                    i0 / imax, i1 / imax,
                        facecolor = lighten_color(COLORS[1+istage], 0.15), zorder=-1000)
                ax1.axvline(i1 / FRAMES_PER_MINUTE, color='#111111', linestyle='--')

        ax1.set_xlim(0, imax / FRAMES_PER_MINUTE)

    # Draw HHMM regime
    ax3 = fig.add_subplot(spec[2:4, 0])

    # Union of the time points that the simplification keeps for any regime boundary,
    # so that all stacked curves are drawn on the same set of time points
    y01IdxAll = np.full(probRegimesSmoothedStacked.shape[0], False, dtype=bool)
    for regimeID in range(numRegimes):
        y01 = pd.DataFrame(np.hstack((
            np.arange(probRegimesSmoothedStacked.shape[0])[:,np.newaxis],
            probRegimesSmoothedStacked[:,regimeID, np.newaxis]))).ffill().bfill().values.copy()
        y01Idx = simpl.simplify_coords_vw_idx(y01, resVW_hhmm)
        y01IdxAll[y01Idx] = True

    for regimeID in range(numRegimes):
        # Lower and upper boundary of this regime's band in the stacked plot
        y01 = pd.DataFrame(probRegimesSmoothedStacked[:,[regimeID, regimeID+1]]).ffill().bfill().values.copy()
        ax3.fill_between(t[y01IdxAll],
                         y01[y01IdxAll, 0], y01[y01IdxAll, 1],
                         color=regimeColors[regimeID], alpha=0.5)
        ax3.plot(t[y01IdxAll], y01[y01IdxAll, 1],
                 color='#444444', linewidth=1, zorder=10)

    ax3.set_xlabel('Time (Minutes)')
    ax3.set_ylabel('Prob(Regime)')
    ax3.set_xlim(0, imax / FRAMES_PER_MINUTE)
    ax3.set_ylim(0, 1.0)

    # Smallest square around all positions; identical for every trajectory panel
    xyMin = np.nanmin(arr[:,0:2], axis=0)
    xyMax = np.nanmax(arr[:,0:2], axis=0)
    xyCent = xyMin * 0.5 + xyMax * 0.5
    xySpan = max(xyMax[0] - xyMin[0], xyMax[1] - xyMin[1]) / 2

    # Draw HHMM trajectories
    for regimeID in range(numRegimes):
        slot = min(regimeOrder[regimeID], 5)
        nrow = slot % 2
        ncol = slot // 2
        ax4 = fig.add_subplot(spec[(nrow*2):(nrow*2+2), ncol+1])

        # Positions of the frames in which this regime is the most probable one
        xy = arr[np.argmax(probRegimes, axis=1) == regimeID, 0:2].copy()
        xy = pd.DataFrame(xy).ffill().bfill().values.copy()

        xy = simpl.simplify_coords_vw(xy, resVW_w)
        # Start a new line segment wherever consecutive points are more than 100 apart
        segmentIDs = np.cumsum(np.linalg.norm(xy - np.roll(xy, 1, axis=0), axis=1) > 100)

        for sid in np.unique(segmentIDs):
            ax4.plot(xy[segmentIDs==sid, 0], xy[segmentIDs==sid, 1],
                     color=regimeColors[regimeID],
                     linewidth=1)

        # Resize web to smallest square
        if xy.shape[0] > 0:
            ax4.set_xlim(xyCent[0] - xySpan, xyCent[0] + xySpan)
            ax4.set_ylim(xyCent[1] - xySpan, xyCent[1] + xySpan)

        ax4.set_axis_off()

    fig.savefig('{}/Fig_4g_{}_{}_{}_{}.pdf'.format(
        outDir, data[modelIdx]['numRegimesHHMM'], modelIdx, modelRepIdx, modelRecIdx))

In [None]:
# Color and trajectory-panel slot of each of the five regimes, in regime-ID order
REGIME_COLORS = [COLORS[i] for i in (1, 5, 3, 2, 4)]
REGIME_ORDER = [0, 4, 3, 2, 1]

# One figure per recording, using model file 0 and model repetition 9
for recID in range(21):
    plot(0, 9, recID, regimeColors=REGIME_COLORS, regimeOrder=REGIME_ORDER)