#### Imports

In [2]:
import os, glob, gc, regex as re, numpy as np, pandas as pd, simplification.cutil as simpl, \
    matplotlib.pyplot as plt, joblib as jl, json, time, matplotlib.gridspec as gridspec, \
    matplotlib.colors as mplcol, colorcet, miniball, matplotlib.ticker as ticker
from matplotlib.patches import FancyBboxPatch
from tqdm import tqdm_notebook as tqdm

# Display the process ID of the Python process running this notebook
os.getpid()

15412

#### Set color scheme for plotting

In [3]:
# Colorblind-safe palette (Wong), written as 8-bit RGB triplets and converted
# to hex color strings for use with matplotlib.
# Source: https://www.nature.com/articles/nmeth.1618
# Source: https://davidmathlogic.com/colorblind/#%23648FFF-%23785EF0-%23DC267F-%23FE6100-%23FFB000
COLORS = [mplcol.rgb2hex(np.array(rgb) / 255.0) for rgb in (
    (0, 0, 0),
    (230, 159, 0),
    (0, 158, 115),
    (0, 114, 178),
    (204, 121, 167),
    (86, 180, 233),
    (240, 228, 66),
    (213, 94, 0),
)]

#### Load HHMM data

In [4]:
# Locate the HHMM result files (polarpos data and accompanying HHMM predictions)
# Alternative pattern previously used for re-saved results:
#   'Y:/wavelet/hhmm-results/*regimes_12minrun_manuallabels_5fold.*.resave.pickle'
fnames = glob.glob('Y:/wavelet/hhmm-results/*regimes_12minrun_manuallabels_5fold.pickle')
# Drop any file whose name contains 'idxmapping'
fnames = [path for path in fnames if 'idxmapping' not in path]
# Keep a single results file.
# NOTE(review): glob returns matches in an unspecified order, so this positional
# pick depends on the directory listing order -- confirm it selects the intended file.
fnames = [fnames[5],]
fnames

['Y:/wavelet/hhmm-results\\5regimes_12minrun_manuallabels_5fold.pickle']

In [None]:
# Start time (seconds since the epoch) of the loading step in this cell
t0 = time.time()

# Run a full garbage collection before loading
gc.collect()

def loadPickle(fn):
    """Load a joblib/pickle file with the garbage collector paused.

    The collector is switched off for the duration of the load (presumably to
    speed up unpickling of large object graphs -- confirm) and is always restored
    to its previous state afterwards, whether the load succeeds or fails.

    Parameters
    ----------
    fn : str
        Path of the file to load.

    Returns
    -------
    object or None
        The loaded object, or None if loading raised an exception (the file name
        and the error message are printed in that case).
    """
    # NOTE(review): joblib.load unpickles data and can execute arbitrary code;
    # only use it on trusted files.
    gcWasEnabled = gc.isenabled()
    gc.disable()
    try:
        return jl.load(fn)
    except Exception as e:
        print(fn, str(e))
        return None
    finally:
        # Runs on both the success and the failure path; the original re-enable
        # call sat after the return statement and was never reached.
        if gcWasEnabled:
            gc.enable()
    
# Load every results file (with a progress bar), keeping only successful loads
data = [loaded for loaded in map(loadPickle, tqdm(fnames)) if loaded is not None]

# End time (seconds since the epoch) of the loading step
t1 = time.time()

#### Load position data

In [9]:
def loadPositionData(modelIdx, modelRepIdx, modelRecIdx):
    """Load one recording's position/orientation array and its stage annotations.

    The position file is searched for in the ``croprot`` directory next to the
    directory of the first input file of recording ``modelRecIdx`` in
    ``data[modelIdx]['fnames']``. The recording info is read from the
    ``recording.json`` located one directory above the position file.

    Parameters
    ----------
    modelIdx : int
        Index into the global ``data`` list of loaded result files.
    modelRepIdx : int
        Unused by this function.
    modelRecIdx : int
        Index of the recording within ``data[modelIdx]['fnames']``.

    Returns
    -------
    arr : numpy.ndarray
        Rows of the position/orientation array selected by the companion
        ``*_abs_filt_interp_mvmt_noborder.idx.npy`` index array.
    recordingInfo : dict
        Parsed ``recording.json`` with an added ``'fname'`` entry. Every entry of
        ``recordingInfo['stages']`` is a list of ``[start, end]`` pairs, each value
        being the position (within the selected rows) of the retained row closest
        to the original value.

    Raises
    ------
    FileNotFoundError
        If no position file matches, or if ``recording.json`` does not exist.
    """
    # Locate the position/orientation file
    pattern = os.path.abspath(os.path.join(os.path.dirname(
        data[modelIdx]['fnames'][modelRecIdx][0]), '../croprot/*dlc_position_orientation.npy'))
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError('No position/orientation file matches: {}'.format(pattern))
    fnamePos = matches[0]

    # Load recording info
    fnameRecInfo = os.path.join(os.path.dirname(os.path.dirname(fnamePos)), 'recording.json')
    if not os.path.exists(fnameRecInfo):
        raise FileNotFoundError('Recording info not found: {}'.format(fnameRecInfo))
    with open(fnameRecInfo, 'r') as f:
        recordingInfo = json.load(f)
    recordingInfo['fname'] = fnamePos
    stages = recordingInfo['stages']

    # Load the original data and the row index once; both are needed below
    arr = np.load(fnamePos)
    arrIdx = np.load(fnamePos.replace('_position_orientation.npy',
                                      '_abs_filt_interp_mvmt_noborder.idx.npy'))

    # Does this recording.json file specify stage ranges, or starting points?
    if isinstance(stages['protoweb'], list):
        # Ranges: make sure every stage is a list of [start, end] pairs
        for st in stages:
            try:
                if not isinstance(stages[st][0], list):
                    stages[st] = [stages[st], ]
            except (TypeError, IndexError, KeyError):
                # Not indexable, or empty: treat the stage as absent
                stages[st] = []
    else:
        # Starting points: each stage runs until the start of the next one,
        # and the last stage runs until the end of the recording
        stages['end'] = arr.shape[0]

        # A missing or negative stabilimentum start is replaced by the end of the
        # recording, which yields a zero-length stage
        if 'stabilimentum' not in stages or not (stages['stabilimentum'] >= 0):
            stages['stabilimentum'] = stages['end']

        # Read all starting points before overwriting any of them with a range
        order = ['protoweb', 'radii', 'spiral_aux', 'spiral_cap', 'stabilimentum', 'end']
        starts = [stages[name] for name in order]
        for name, start, stop in zip(order, starts, starts[1:]):
            stages[name] = [[start, stop],]
        del stages['end']

    # Convert to indices used in analysis: replace each boundary by the position
    # of the closest retained row (arrIdx is presumably a boolean mask over the
    # rows of the original array -- TODO confirm)
    keptRows = np.argwhere(arrIdx).T[0]
    for st in stages:
        for bounds in stages[st]:
            for m in range(2):
                bounds[m] = np.argmin(np.abs(keptRows - bounds[m]))

    # Subset by index
    arr = arr[arrIdx,:]

    # Done
    return arr, recordingInfo

#### Collect HHMM regime probabilities and compute polar position data

In [131]:
def computeData(modelIdx, modelRepIdx, modelRecIdx):
    """Collect position data and per-regime probabilities for one recording.

    Parameters
    ----------
    modelIdx : int
        Index into the global ``data`` / ``fnames`` lists.
    modelRepIdx : int
        Index into ``data[modelIdx]['models']``.
    modelRecIdx : int
        Index of the recording.

    Returns
    -------
    arr : numpy.ndarray
        Position/orientation rows as returned by ``loadPositionData``.
    probRegimes : numpy.ndarray
        One row per entry of the recording's index mapping and one column per
        regime; gaps (mapping index < 0) are forward- then backward-filled.
    probRegimesSmoothed : numpy.ndarray
        Rolling maximum (window of 500 rows) of ``probRegimes``, renormalized so
        that each row sums to 1.
    recordingInfo : dict
        Recording info as returned by ``loadPositionData``.
    """
    # Load position data
    arr, recordingInfo = loadPositionData(modelIdx, modelRepIdx, modelRecIdx)

    # Load index mapping; its file name is derived from the results file name
    fnameResults = fnames[modelIdx]
    fnameSuffix = re.search('([.0-9]*\\.pickle)$', fnameResults).group(0)
    fnameIdxMapping = fnameResults.replace(fnameSuffix, '.idxmapping.pickle').replace(
        '.resave', '').replace('.1', '')
    arrIdxMapping = jl.load(fnameIdxMapping)

    modelEntry = data[modelIdx]['models'][modelRepIdx]

    # Load raw regime probabilities (exponentiated below, so presumably stored as
    # log-probabilities -- TODO confirm); recompute them from the model when no
    # stored prediction is available
    try:
        d = modelEntry['statesPredProb'][modelRecIdx].copy()
    except Exception:
        d = modelEntry['model'].predict_log_proba(
            modelEntry['model-fit-states'][modelRecIdx]).copy()

    # Regime ID of every model state, parsed from a leading 'r<number>_' in the
    # state name; states without that prefix get -1
    stateMatches = [re.search('^r([0-9]*)_', state.name) for state in modelEntry['model'].states]
    regimeIDs = np.array([(int(m.group(1)) if m is not None else -1) for m in stateMatches])

    # Per regime, take the maximum over all states belonging to that regime
    probRegimes = np.zeros((d.shape[0], data[modelIdx]['numRegimesHHMM']))
    for regimeID in range(probRegimes.shape[1]):
        probRegimes[:, regimeID] = np.nanmax(d[:, np.argwhere(regimeIDs == regimeID)[:,0]], axis=1)
    probRegimes = np.exp(probRegimes)

    # Reshape array to re-introduce duplicate states using the index loaded above;
    # negative mapping entries become NaN rows
    probRegimes = np.array([(probRegimes[i,:] if i >= 0 else np.full(
        probRegimes.shape[1], np.nan)) for i in arrIdxMapping[modelRecIdx]])

    # Fill the NaN rows column-wise: forward first, then backward for leading gaps
    probRegimes = pd.DataFrame(probRegimes).ffill().bfill().values.copy()

    # Smooth and stack regime probabilities
    # NOTE(review): with the default min_periods the first 499 rows of the rolling
    # maximum are NaN -- confirm this is intended.
    probRegimesSmoothed = pd.DataFrame(probRegimes).rolling(window=500).max().values.copy()
    probRegimesSmoothed /= np.sum(probRegimesSmoothed, axis=1)[:,np.newaxis]

    # Done
    return arr, probRegimes, probRegimesSmoothed, recordingInfo

#### Plot spider position and regime probabilities

In [None]:
# For every model replicate, draw one figure in which each example recording gets
# one panel per regime: the full trajectory in light gray, overlaid in color with
# the rows whose most probable regime is the panel's regime.
modelIdx = 0

# Recordings to plot (indices into the per-recording results list `d` below)
examples = np.arange(21, dtype=int)

# For each regime ID: panel column within a recording's group, and color slot
REGIME_REMAP = [0, 2, 4, 1, 3]

# Plot
resVW_wAll = 500  # simplify_coords_vw tolerance for the full trajectory
resVW_w = 100     # simplify_coords_vw tolerance for the per-regime trajectory
nrows = 6
nRegimes = 5
# NOTE(review): recordings fill the grid in column groups of `nrows` recordings, so
# ceil(len(examples) / nrows) groups are used, while the divisor here is nRegimes;
# the background box size below is also hard-coded -- confirm the intended layout.
ncols = int(np.ceil(len(examples) / nRegimes)) * nRegimes

for modelRepIdx in tqdm(range(len(data[modelIdx]['models']))):
    Nrec = len(data[modelIdx]['fnames'])
    d = [computeData(modelIdx, modelRepIdx, modelRecIdx) for modelRecIdx in tqdm(range(Nrec), leave=False)]

    fig, ax = plt.subplots(nrows, ncols, figsize=(ncols * 2, nrows * 2.5))
    fig.subplots_adjust(hspace=-0.5, wspace=-0.5)

    # Invisible full-figure axes behind the panels, used to draw the group boxes
    axBg = fig.add_axes([0,0,1,1])
    axBg.xaxis.set_visible(False)
    axBg.yaxis.set_visible(False)
    axBg.set_axis_off()
    axBg.set_zorder(-1000)

    for i in range(nrows):
        for j in range(ncols):
            ax[i][j].set_axis_off()
            ax[i][j].set_aspect('equal')

    # Wrapping the iterable (rather than the enumerate) gives tqdm a known total
    for i, ex in enumerate(tqdm(examples, leave=False)):
        arr, probRegimes = d[ex][0], d[ex][1]
        r = i % nrows
        groupCol = (i // nrows) * nRegimes

        # Full trajectory: identical for every regime panel of this recording, so
        # fill, simplify and segment it once. A new segment starts wherever two
        # consecutive points are more than 100 units apart.
        xyAll = pd.DataFrame(arr[:, 0:2]).ffill().bfill().values.copy()
        xyAll = simpl.simplify_coords_vw(xyAll, resVW_wAll)
        segmentIDsAll = np.cumsum(np.linalg.norm(xyAll - np.roll(xyAll, 1, axis=0), axis=1) > 100)

        # Most probable regime of every row
        bestRegime = np.argmax(probRegimes, axis=1)

        for regimeID in range(probRegimes.shape[1]):
            c = groupCol + REGIME_REMAP[regimeID]

            if regimeID == 0:
                # Rounded box behind this recording's group of panels
                axBg.add_patch(FancyBboxPatch(
                    (c * 0.982 / ncols + 0.008, r * 0.982 / nrows - 0.002), 0.331, 0.258,
                    boxstyle="round,pad=-0.0040,rounding_size=0.007",
                    ec="#aaaaaa", fc='#fbfbfb', clip_on=False,
                    mutation_aspect=ncols / nrows))

            for sid in np.unique(segmentIDsAll):
                ax[r][c].plot(
                    xyAll[segmentIDsAll==sid, 0], xyAll[segmentIDsAll==sid, 1],
                    color='#bbbbbb', linewidth=0.5)

            # Rows assigned to this regime; nothing to draw if there are none
            xy = arr[bestRegime == regimeID, 0:2].copy()
            if xy.shape[0] > 0:
                xy = pd.DataFrame(xy).ffill().bfill().values.copy()
                xy = simpl.simplify_coords_vw(xy, resVW_w)
                segmentIDs = np.cumsum(np.linalg.norm(xy - np.roll(xy, 1, axis=0), axis=1) > 100)

                for sid in np.unique(segmentIDs):
                    ax[r][c].plot(
                        xy[segmentIDs==sid, 0], xy[segmentIDs==sid, 1],
                        color=COLORS[REGIME_REMAP[regimeID]+1], linewidth=1)

            ax[r][c].set_axis_off()
            ax[r][c].set_xlim(100, 924)
            ax[r][c].set_ylim(100, 924)

    fig.tight_layout()
    fig.savefig('C:/Users/acorver/Desktop/paper-figures/Fig_4f_{}_{}.pdf'.format(
        data[modelIdx]['numRegimesHHMM'], modelRepIdx))
    # The figure is left open so the notebook still renders it; call plt.close(fig)
    # here if memory use across replicates becomes a problem.

### Supplementary Figure

In [None]:
# Supplementary figure: the same panel layout as the main figure, but for a single
# recording (modelRecIdx), with one group of panels per model replicate.
modelIdx = 0
modelRecIdx = 10

Nrec = len(data[modelIdx]['fnames'])
d = [computeData(modelIdx, modelRepIdx, modelRecIdx) for modelRepIdx in tqdm(
    range(len(data[modelIdx]['models'])), leave=False)]

# One example per model replicate
examples = np.arange(len(d), dtype=int)

# Regimes keep their own ID as panel column and color slot.
# NOTE(review): COLORS has 8 entries and slot 0 is skipped, so more than 7 regimes
# would index past the end of the palette.
REGIME_REMAP = np.arange(data[modelIdx]['numRegimesHHMM'], dtype=int)

# Plot
resVW_wAll = 500  # simplify_coords_vw tolerance for the full trajectory
resVW_w = 100     # simplify_coords_vw tolerance for the per-regime trajectory
nrows = 6
nRegimes = data[modelIdx]['numRegimesHHMM']
# NOTE(review): examples fill the grid in column groups of `nrows`, so
# ceil(len(examples) / nrows) groups are used, while the divisor here is nRegimes;
# the background box size below is also hard-coded -- confirm the intended layout.
ncols = int(np.ceil(len(examples) / nRegimes)) * nRegimes
fig, ax = plt.subplots(nrows, ncols, figsize=(ncols * 2, nrows * 2.5))
fig.subplots_adjust(hspace=-0.5, wspace=-0.5)

# Invisible full-figure axes behind the panels, used to draw the group boxes
axBg = fig.add_axes([0,0,1,1])
axBg.xaxis.set_visible(False)
axBg.yaxis.set_visible(False)
axBg.set_axis_off()
axBg.set_zorder(-1000)

for i in range(nrows):
    for j in range(ncols):
        ax[i][j].set_axis_off()
        ax[i][j].set_aspect('equal')

# Wrapping the iterable (rather than the enumerate) gives tqdm a known total
for i, ex in enumerate(tqdm(examples, leave=False)):
    arr, probRegimes = d[ex][0], d[ex][1]
    r = i % nrows
    groupCol = (i // nrows) * nRegimes

    # Full trajectory: identical for every regime panel of this example, so fill,
    # simplify and segment it once. A new segment starts wherever two consecutive
    # points are more than 100 units apart.
    xyAll = pd.DataFrame(arr[:, 0:2]).ffill().bfill().values.copy()
    xyAll = simpl.simplify_coords_vw(xyAll, resVW_wAll)
    segmentIDsAll = np.cumsum(np.linalg.norm(xyAll - np.roll(xyAll, 1, axis=0), axis=1) > 100)

    # Most probable regime of every row
    bestRegime = np.argmax(probRegimes, axis=1)

    for regimeID in range(probRegimes.shape[1]):
        c = groupCol + REGIME_REMAP[regimeID]

        if regimeID == 0:
            # Rounded box behind this example's group of panels
            axBg.add_patch(FancyBboxPatch(
                (c * 0.982 / ncols + 0.008, r * 0.982 / nrows - 0.002), 0.331, 0.258,
                boxstyle="round,pad=-0.0040,rounding_size=0.007",
                ec="#aaaaaa", fc='#fbfbfb', clip_on=False,
                mutation_aspect=ncols / nrows))

        for sid in np.unique(segmentIDsAll):
            ax[r][c].plot(
                xyAll[segmentIDsAll==sid, 0], xyAll[segmentIDsAll==sid, 1],
                color='#bbbbbb', linewidth=0.5)

        # Rows assigned to this regime; nothing to draw if there are none
        xy = arr[bestRegime == regimeID, 0:2].copy()
        if xy.shape[0] > 0:
            xy = pd.DataFrame(xy).ffill().bfill().values.copy()
            xy = simpl.simplify_coords_vw(xy, resVW_w)
            segmentIDs = np.cumsum(np.linalg.norm(xy - np.roll(xy, 1, axis=0), axis=1) > 100)

            for sid in np.unique(segmentIDs):
                ax[r][c].plot(
                    xy[segmentIDs==sid, 0], xy[segmentIDs==sid, 1],
                    color=COLORS[REGIME_REMAP[regimeID]+1], linewidth=1)

        ax[r][c].set_axis_off()
        ax[r][c].set_xlim(100, 924)
        ax[r][c].set_ylim(100, 924)

fig.tight_layout()
fig.savefig('C:/Users/acorver/Desktop/paper-figures/Fig_Suppl_4X_{}_{}.pdf'.format(
    data[modelIdx]['numRegimesHHMM'], modelRecIdx))

### Miscellaneous statistics (TODO: move to a separate script)

In [166]:
# Size of each 'model-fit-states' entry of the first model of the first results file
numStatesAll = np.array([
    states.size for states in data[0]['models'][0]['model-fit-states']
])
# A sequence of N states contains N - 1 transitions
numTransAll = numStatesAll - 1
numTransAll

array([11109, 11645, 20964, 22734, 13733,  9247,  9879, 11882,  6567,
       11434, 11460, 15680, 13135,  7427, 13254,  7798,  7269, 14795,
       12946, 15854,  9955])

In [167]:
# Mean and population standard deviation (ddof=0) of the transition counts
numTransAll.mean(), numTransAll.std()

(12322.238095238095, 4051.0586377103705)

In [169]:
# 10th and 90th percentiles of the transition counts
tuple(np.percentile(numTransAll, [10, 90]))

(7427.0, 15854.0)