In [None]:
import numpy as np, os, glob, matplotlib.pyplot as plt, pandas as pd, colorcet, matplotlib as mpl
from tqdm import tqdm_notebook as tqdm
import miniball
from scipy.ndimage import binary_dilation, binary_erosion
from skimage.measure import label
import shapely.ops, shapely.geometry as geom

In [None]:
# Hex colour for each manually named behaviour. The plotting functions below look a colour up
# by the base name of a cluster label and append a two-character alpha suffix (e.g. '88').
COLORS_CLUSTERS = {
 'walk': '#0072b2',
 'one-leg-after-other': '#eee8c7',
 'extrude-slow': '#f0e442',
 'extrude': '#009e73',
 'left-leg': '#e69f00',
 'both-legs': '#ccc8a7',
 'stationary': '#000000',
 'stationary-posterior': '#666666',
 'stationary-anterior': '#666666',
 'stabilimentum': '#cc79a7',
 # No colour assigned: appending an alpha suffix to None would raise a TypeError, so this
 # entry cannot be drawn. The 'noisy' group is deleted from the label dicts after loading.
 'noisy': None,
 'right-leg': '#56b4e9',
 'bend-abdomen': '#d55e00'
}

In [None]:
# Glob pattern matching one posterior cluster-assignment file per recording directory
fnameP = 'Z:/behavior/*/wavelet/rawmvmt_dlc_euclidean-midline_no-abspos_no-vel_00000000010001000000010001_60_16_meansub_scalestd_hipow_tsne_no-pca_perplexity_100_200000_2000_euclidean.clusters.npy'
# File name of the anterior cluster assignments, looked up next to each posterior file
fnameA = 'rawmvmt_dlc_euclidean_no-abspos_no-vel_00000010001000000010001000_60_16_meansub_scalestd_hipow_tsne_no-pca_perplexity_100_200000_2000_euclidean.clusters.npy'

# Keep every match whose path does not contain 'RIG'
fnamesP = list(filter(lambda path: 'RIG' not in path, glob.glob(fnameP)))
# Pair each posterior file with the anterior file in the same directory
fnamesA = list(map(lambda path: os.path.join(os.path.dirname(path), fnameA), fnamesP))
len(fnamesP)

In [None]:
# Passed as mergeManual= to splitManualLabelsFromContiguity; also adds '_mergemanual' to saved figure names
MERGE_MANUAL = True

In [None]:
def _firstColumn(path):
    """Load a .npy file and return its first column."""
    return np.load(path)[:, 0]

# Per-recording cluster assignments (first column of each '.clusters.npy' file)
arrClustersP = list(map(_firstColumn, fnamesP))
arrClustersA = list(map(_firstColumn, fnamesA))

In [None]:
def _rowsWithoutNaN(fnameClusters):
    """Boolean mask over rows of the sibling '.npy' file: True where the row has no NaN."""
    data = np.load(fnameClusters.replace('.clusters.npy', '.npy'))
    return ~np.isnan(data).any(axis=1)

# One usable-row mask per recording
arrUsetsP = [_rowsWithoutNaN(path) for path in fnamesP]
arrUsetsA = [_rowsWithoutNaN(path) for path in fnamesA]

In [None]:
def _tsneColumns(fnameClusters):
    """Load the sibling '.filtered2.npy' file and return its columns 2 and 3."""
    return np.load(fnameClusters.replace('.clusters.npy', '.filtered2.npy'))[:, 2:4]

# Per-recording 2-D embedding coordinates
arrTsneP = [_tsneColumns(path) for path in fnamesP]
arrTsneA = [_tsneColumns(path) for path in fnamesA]

In [None]:
def toNumber(x):
    """Parse ``x`` as an integer.

    Returns:
        ``int(x)`` when the conversion succeeds, otherwise -1 (the sentinel that
        ``loadLabels`` filters out).
    """
    try:
        return int(x)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures map to the sentinel; other exceptions
        # (e.g. KeyboardInterrupt) are no longer swallowed.
        return -1

def loadLabels(fnameLabels):
    """Read a cluster-name file into ``{label: [cluster IDs]}``.

    File format as parsed here:
      * a line containing ':' starts a new label; the label is the text before the first ':'
      * every other non-blank line is a comma-separated list of cluster IDs for the
        current label; tokens that are not non-negative integers are ignored.

    ID lines are accumulated, so a label whose IDs span several lines keeps all of
    them (previously each additional line replaced the earlier ones).

    Args:
        fnameLabels: path of the text file to read.

    Returns:
        dict mapping label -> list of int cluster IDs, in file order. Labels with no
        ID line do not appear. ID lines before the first label are stored under ''.
    """
    with open(fnameLabels, 'r') as f:
        lines = f.read().split('\n')

    clusterLabels = {}
    curLabel = ''
    for line in lines:
        if ':' in line:
            curLabel = line[:line.find(':')]
        elif len(line.strip()) > 0:
            # Parse each token once, then drop the -1 sentinel / negative numbers
            numbers = (toNumber(token) for token in line.split(','))
            clusterLabels.setdefault(curLabel, []).extend(n for n in numbers if n >= 0)
    return clusterLabels

# Integer flags for the two body regions (no use of them is visible in this part of the notebook)
ANTERIOR = 1
POSTERIOR = 0

# Manually curated cluster-name files, one per embedding (extended-length Windows paths)
fnameClusterLabelsA = '\\\\?\\Y:\\wavelet\\clips\\rawmvmt_dlc_euclidean_no-abspos_no-vel_00000010001000000010001000_60_16_meansub_scalestd\\cluster_names.txt'
fnameClusterLabelsP = '\\\\?\\Y:\\wavelet\\clips\\rawmvmt_dlc_euclidean-midline_no-abspos_no-vel_00000000010001000000010001_60_16_meansub_scalestd\\cluster_names.txt'
clusterLabelsAmanual = loadLabels(fnameClusterLabelsA)
clusterLabelsPmanual = loadLabels(fnameClusterLabelsP)
# Remove the 'noisy' group so its clusters get no manual label downstream.
# Raises KeyError if a file defines no 'noisy' group.
del clusterLabelsAmanual['noisy']
del clusterLabelsPmanual['noisy']

In [None]:
# Watershed segmentation files of the shared embedding, one per body region
fnameSharedTsneP = '\\\\?\\Y:\\wavelet\\rawmvmt_dlc_euclidean-midline_no-abspos_no-vel_00000000010001000000010001_60_16_meansub_scalestd\\rawmvmt_dlc_euclidean-midline_no-abspos_no-vel_00000000010001000000010001_60_16_meansub_scalestd_hipow_tsne_no-pca_perplexity_100_200000_2000_euclidean.smoothed.watershed.npy'
fnameSharedTsneA = '\\\\?\\Y:\\wavelet\\rawmvmt_dlc_euclidean_no-abspos_no-vel_00000010001000000010001000_60_16_meansub_scalestd\\rawmvmt_dlc_euclidean_no-abspos_no-vel_00000010001000000010001000_60_16_meansub_scalestd_hipow_tsne_no-pca_perplexity_100_200000_2000_euclidean.smoothed.watershed.npy'
# Take the [:, :, 0, 1] plane as integers; below it is compared element-wise against cluster IDs,
# i.e. it is used as a 2-D map of cluster IDs.
arrSharedTsneP = np.load(fnameSharedTsneP)[:, :, 0, 1].astype(int)
arrSharedTsneA = np.load(fnameSharedTsneA)[:, :, 0, 1].astype(int)

In [None]:
def splitManualLabelsFromContiguity(clusterLabelsManual, arrSharedTsne, mergeManual=True):
    """Split each manual label into one entry per cluster, optionally merging touching ones.

    Every cluster ID of every manual label starts as its own entry, keyed
    '<label>-<clusterID>', holding the boolean region ``arrSharedTsne == clusterID``.
    With ``mergeManual`` set, two entries sharing the same label are fused whenever one
    region overlaps the other region dilated by one step; this repeats until no pair
    qualifies. The fused entry keeps the key of the first entry of the pair.

    Args:
        clusterLabelsManual: dict label -> iterable of cluster IDs.
        arrSharedTsne: array of cluster IDs defining the regions.
        mergeManual: merge contiguous same-label entries when True.

    Returns:
        (regions, ids): ``regions`` maps key -> (boolean mask, list of cluster IDs);
        ``ids`` maps key -> list of cluster IDs.
    """
    regions = {
        '{}-{}'.format(name, clid): (arrSharedTsne == clid, [clid, ])
        for name, members in clusterLabelsManual.items()
        for clid in members
    }

    def _baseName(key):
        # Everything before the last '-' is the manual label
        return key[:key.rfind('-')]

    def _nextMergeablePair():
        # First (keep, drop) pair, in dict order, with equal label and touching regions
        for keep in regions:
            for drop in regions:
                if keep == drop or _baseName(keep) != _baseName(drop):
                    continue
                grown = binary_dilation(regions[drop][0], iterations=1)
                if np.sum(regions[keep][0] & grown) > 0:
                    return keep, drop
        return None

    if mergeManual:
        # Restart the search after every merge; stop when a full scan finds nothing
        pair = _nextMergeablePair()
        while pair is not None:
            keep, drop = pair
            regions[keep] = (regions[keep][0] | regions[drop][0],
                             regions[keep][1] + regions[drop][1])
            del regions[drop]
            pair = _nextMergeablePair()

    return regions, {key: entry[1] for key, entry in regions.items()}

# Per-region label tables: '*cl' keeps (mask, cluster IDs) per key, the plain dict keeps only the IDs
clusterLabelsAcl, clusterLabelsA = splitManualLabelsFromContiguity(
    clusterLabelsAmanual, arrSharedTsneA, mergeManual=MERGE_MANUAL)
clusterLabelsPcl, clusterLabelsP = splitManualLabelsFromContiguity(
    clusterLabelsPmanual, arrSharedTsneP, mergeManual=MERGE_MANUAL)

In [None]:
# Convert cluster IDs to manual cluster IDs
def _manualIndexLookup(clusterLabels):
    """Map each raw cluster ID to the position of the first label entry that lists it."""
    lookup = {}
    for pos, members in enumerate(clusterLabels.values()):
        for clid in members:
            lookup.setdefault(clid, pos)
    return lookup

def _toManualIDs(arrClusters, clusterLabels):
    """Translate every sample's raw cluster ID into its manual index (NaN when unlabeled).

    The lookup table is built once, instead of re-scanning every label for every sample.
    """
    lookup = _manualIndexLookup(clusterLabels)
    return [[lookup.get(int(y), np.nan) for y in x] for x in tqdm(arrClusters, leave=False)]

def _fillUnlabeled(x):
    """Replace NaN by the previous label, then by the next one for leading gaps."""
    # .ffill()/.bfill() replace the deprecated fillna(method=...) spelling
    return pd.DataFrame(x).ffill().bfill().values[:,0].astype(int)

arrClustersAmanual = _toManualIDs(arrClustersA, clusterLabelsA)
arrClustersPmanual = _toManualIDs(arrClustersP, clusterLabelsP)

# Gap-filled sequences: every sample carries a manual index
arrClustersAmanualNoNA = [_fillUnlabeled(x) for x in arrClustersAmanual]
arrClustersPmanualNoNA = [_fillUnlabeled(x) for x in arrClustersPmanual]

# Integer sequences with -1 marking unlabeled samples
arrClustersAmanual = [np.array([(-1 if np.isnan(y) else y) for y in x]).astype(int) for x in arrClustersAmanual]
arrClustersPmanual = [np.array([(-1 if np.isnan(y) else y) for y in x]).astype(int) for x in arrClustersPmanual]

In [None]:
def isEmbeddingStable(x, minLength=12):
    """Flag samples that belong to a sufficiently long run of identical labels.

    Args:
        x: 1-D sequence of labels.
        minLength: minimum run length (in samples) for a run to count as stable.
            Defaults to 12, the value that was previously hard-coded.

    Returns:
        Boolean array of ``len(x)``; True where the sample lies in a run of at least
        ``minLength`` consecutive equal labels. Empty input gives an empty array.
    """
    labels = np.asarray(x)
    if labels.size == 0:
        return np.zeros(0, dtype=bool)
    # Index of the first sample of every run
    starts = np.concatenate(([0], np.flatnonzero(labels[1:] != labels[:-1]) + 1))
    lengths = np.diff(np.concatenate((starts, [labels.size])))
    # Broadcast each run's verdict back onto its samples (plain bool: np.bool no longer exists)
    return np.repeat(lengths >= minLength, lengths)
# Per-recording stability masks computed from the manual label sequences (-1 = unlabeled)
arrStableA = [isEmbeddingStable(x) for x in tqdm(arrClustersAmanual, leave=False)]
arrStableP = [isEmbeddingStable(x) for x in tqdm(arrClustersPmanual, leave=False)]

In [None]:
def keepStable(x, st):
    """Return a float copy of ``x`` with unstable or negative entries set to NaN.

    Args:
        x: array of labels (left unmodified).
        st: boolean mask of the same length; False marks entries to blank out.
    """
    values = x.astype(np.float64)
    blank = ~st | (values < 0)
    values[blank] = np.nan
    return values

# Label sequences restricted to stable runs; blanked samples take the previous stable label
# (or the next one for a leading gap). Uses .ffill()/.bfill() instead of the deprecated
# fillna(method=...) spelling.
arrClustersAmanualStableNoNA = [
    pd.DataFrame(keepStable(x, st)).ffill().bfill().values[:,0].astype(int)
    for x, st in zip(arrClustersAmanual, arrStableA)]
arrClustersPmanualStableNoNA = [
    pd.DataFrame(keepStable(x, st)).ffill().bfill().values[:,0].astype(int)
    for x, st in zip(arrClustersPmanual, arrStableP)]

In [None]:
def _changePoints(seq):
    """Indices i at which seq[i + 1] differs from seq[i]."""
    return np.flatnonzero(np.diff(seq) != 0)

# Transition indices per recording, for all labels and for stable-only labels
idxJumpA = [_changePoints(seq) for seq in arrClustersAmanualNoNA]
idxJumpStableA = [_changePoints(seq) for seq in arrClustersAmanualStableNoNA]

idxJumpP = [_changePoints(seq) for seq in arrClustersPmanualNoNA]
idxJumpStableP = [_changePoints(seq) for seq in arrClustersPmanualStableNoNA]

In [None]:
def _jumpLengths(tsne, jumps):
    """Euclidean distance between the embedding points just before and just after each jump."""
    before = tsne[jumps]
    after = tsne[jumps+1]
    return np.linalg.norm(before - after, axis=1)

# Distance travelled in the embedding at every label transition, per recording
jumpDistsA = [_jumpLengths(tsne, jumps) for tsne, jumps in zip(arrTsneA, idxJumpA)]
jumpDistsStableA = [_jumpLengths(tsne, jumps) for tsne, jumps in zip(arrTsneA, idxJumpStableA)]

jumpDistsP = [_jumpLengths(tsne, jumps) for tsne, jumps in zip(arrTsneP, idxJumpP)]
jumpDistsStableP = [_jumpLengths(tsne, jumps) for tsne, jumps in zip(arrTsneP, idxJumpStableP)]

In [None]:
def _meanPeakNormalizedDensity(arrTsne):
    """Average, over recordings, of each recording's 200x200 histogram divided by its own maximum."""
    normalized = []
    for pts in arrTsne:
        counts = np.histogram2d(pts[:,0], pts[:,1], bins=(200, 200), range=((0, 1), (0, 1)))[0]
        normalized.append(counts / np.max(counts))
    return np.mean(normalized, axis=0)

densTsneA = _meanPeakNormalizedDensity(arrTsneA)
densTsneP = _meanPeakNormalizedDensity(arrTsneP)

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(8, 4))

# Left panel: bins of the posterior density above 0.1% of its maximum
ax[0].set_title('Posterior Density > 0.1% of max')
ax[0].imshow(densTsneP > 0.001 * np.max(densTsneP))

# Right panel: same outline for the anterior density
ax[1].set_title('Anterior Density > 0.1% of max')
ax[1].imshow(densTsneA > 0.001 * np.max(densTsneA))

fig.show()

In [None]:
# Compute diameter of circle around the representative tSNE mass
def _outlineDiameter(dens):
    """Diameter of the smallest ball enclosing all bins above 0.1% of the maximum, in units of the 200-bin grid width."""
    occupied = np.argwhere(dens > np.max(dens) * 0.001)
    ball = miniball.Miniball(occupied)
    return 2 * np.sqrt(ball.squared_radius()) / 200.0

maxJumpA = _outlineDiameter(densTsneA)
maxJumpP = _outlineDiameter(densTsneP)
maxJumpA, maxJumpP
# = (0.550, 0.597)

In [None]:
def _fdBinCount(dists, upper):
    """Number of bins the 'fd' rule picks for the pooled distances on [0, upper]."""
    counts, _edges = np.histogram(np.hstack(dists), bins='fd', range=(0, upper))
    return counts.size

# Use one bin count per region (the larger of the two estimates) so that the all-jumps
# and stable-jumps histograms share the same bins.
binsDistsA = binsDistsStableA = max(
    _fdBinCount(jumpDistsA, maxJumpA), _fdBinCount(jumpDistsStableA, maxJumpA))

binsDistsP = binsDistsStableP = max(
    _fdBinCount(jumpDistsP, maxJumpP), _fdBinCount(jumpDistsStableP, maxJumpP))

In [None]:
def _jumpFractions(dists, nBins, upper):
    """Per-recording histogram of jump distances on [0, upper], each divided by its own total."""
    fractions = []
    for d in dists:
        counts = np.histogram(d, bins=nBins, range=(0, upper))[0]
        fractions.append(counts / np.sum(counts))
    return fractions

histDistsA = _jumpFractions(jumpDistsA, binsDistsA, maxJumpA)
histDistsP = _jumpFractions(jumpDistsP, binsDistsP, maxJumpP)

histDistsStableA = _jumpFractions(jumpDistsStableA, binsDistsStableA, maxJumpA)
histDistsStableP = _jumpFractions(jumpDistsStableP, binsDistsStableP, maxJumpP)

In [None]:
def _plotQuartileBand(axis, hists, nBins, color):
    """Draw the 25-75% band (gray) and the median curve (``color``) across recordings.

    ``hists`` is a list of per-recording histograms with ``nBins`` bins each; the x axis
    runs from 0 to 1 over the bins.
    """
    stacked = np.array(hists)
    xs = np.linspace(0, 1, nBins)
    lower = np.percentile(stacked, 25, axis=0)
    upper = np.percentile(stacked, 75, axis=0)
    axis.fill_between(xs, lower, upper, color='gray', alpha=0.25)
    axis.plot(xs, np.median(stacked, axis=0), color=color)

fig, ax = plt.subplots(1, 2, figsize=(10, 4))

# Same draw order as before: stable (red) first, then all jumps (blue)
_plotQuartileBand(ax[0], histDistsStableA, binsDistsStableA, 'red')
_plotQuartileBand(ax[0], histDistsA, binsDistsA, 'blue')
ax[0].set_title('Anterior Jumps (Stable=Red)')

_plotQuartileBand(ax[1], histDistsStableP, binsDistsStableP, 'red')
_plotQuartileBand(ax[1], histDistsP, binsDistsP, 'blue')
ax[1].set_title('Posterior Jumps (Stable=Red)')

ax[0].set_ylabel('PDF')

ax[0].set_xlabel('Transition Jump Distance in t-SNE space, \nNormalized to Diameter of t-SNE Density Outline')
ax[1].set_xlabel('Transition Jump Distance in t-SNE space, \nNormalized to Diameter of t-SNE Density Outline')

fig.show()
#fig.savefig('C:/Users/acorver/Desktop/paper-figures/transitions_distribution.pdf', bbox_inches = 'tight')

In [None]:
def getTransitionMatrix(seqAll, idx, N=None):
    """Count label-to-label transitions in one sequence.

    Args:
        seqAll: list of integer label sequences.
        idx: index of the sequence to analyse.
        N: number of states; defaults to the largest label over all sequences plus one.

    Returns:
        (probs, counts): ``counts[a, b]`` is how often label ``a`` is directly followed by a
        different label ``b``; ``probs`` is ``counts`` with every non-empty row divided by
        its sum (rows without transitions stay all zero). Self-transitions are not counted.
    """
    seq = seqAll[idx]
    if N is None:
        N = np.max([np.max(x) for x in seqAll]) + 1

    counts = np.zeros((N, N), dtype=np.float64)
    for prev, cur in zip(seq[:-1], seq[1:]):
        if cur != prev:
            counts[prev, cur] += 1

    totals = counts.sum(axis=1)
    occupied = totals > 0
    probs = counts.copy()
    probs[occupied] = counts[occupied] / totals[occupied][:, np.newaxis]
    return probs, counts

In [None]:
# Number of states per region, computed once (it was previously recomputed for every recording)
nStatesA = np.max([np.max(x)+1 for x in arrClustersAmanualNoNA])
nStatesP = np.max([np.max(x)+1 for x in arrClustersPmanualNoNA])

# One (probabilities, counts) pair per recording
transitionMtxA = [getTransitionMatrix(arrClustersAmanualNoNA, i, nStatesA)
    for i in tqdm(range(len(arrClustersAmanualNoNA)), leave=False)]
transitionMtxP = [getTransitionMatrix(arrClustersPmanualNoNA, i, nStatesP)
    for i in tqdm(range(len(arrClustersPmanualNoNA)), leave=False)]

In [None]:
# NOTE: this cell rebinds transitionMtxA / transitionMtxP from a list of
# (probabilities, counts) pairs to a single averaged matrix, so it must run exactly once
# after the cell that builds the per-recording pairs.

# Average raw transition counts across recordings
transitionMtxArates = np.mean([pair[1] for pair in transitionMtxA], axis=0)
transitionMtxPrates = np.mean([pair[1] for pair in transitionMtxP], axis=0)

# Average row-normalized transition probabilities across recordings
transitionMtxA = np.mean([pair[0] for pair in transitionMtxA], axis=0)
transitionMtxP = np.mean([pair[0] for pair in transitionMtxP], axis=0)

# Renormalize (only affects small number of rows)
transitionMtxA = transitionMtxA / transitionMtxA.sum(axis=1)[:, np.newaxis]
transitionMtxP = transitionMtxP / transitionMtxP.sum(axis=1)[:, np.newaxis]

In [None]:
# Rows without any transition became NaN during renormalization; zero them in place
np.putmask(transitionMtxA, np.isnan(transitionMtxA), 0)
np.putmask(transitionMtxP, np.isnan(transitionMtxP), 0)

In [None]:
# Full watershed arrays (all axes kept); the plot functions below index them as [:, :, 0, 1]
arrWatershedA = np.load(fnameSharedTsneA)
arrWatershedP = np.load(fnameSharedTsneP)

In [None]:
def _clusterMedians(arrTsne, arrClustersNoNA, nClusters):
    """Median embedding coordinate of every manual cluster, pooled over recordings.

    The per-recording arrays are concatenated once, instead of once per cluster.
    Returns an array with one row per cluster index in ``range(nClusters)``.
    """
    allPoints = np.vstack(arrTsne)
    allLabels = np.hstack(arrClustersNoNA)
    return np.array([np.median(allPoints[allLabels == k, :], axis=0)
                     for k in tqdm(range(nClusters), leave=False)])

peaksA = _clusterMedians(arrTsneA, arrClustersAmanualNoNA, transitionMtxA.shape[0])
peaksP = _clusterMedians(arrTsneP, arrClustersPmanualNoNA, transitionMtxP.shape[0])

In [None]:
# Load densities
def _meanClippedDensity(arrTsne):
    """Average 200x200 embedding histogram across recordings.

    Rows containing NaN are dropped, and each recording's histogram is clipped at its own
    99th percentile before averaging. Being a function, it no longer leaves ``_``, ``H``,
    ``arr`` and ``Hs`` behind in the notebook namespace.
    """
    hists = []
    for pts in arrTsne:
        pts = pts[~np.any(np.isnan(pts), axis=1)]
        counts = np.histogram2d(pts[:,0], pts[:,1], bins=(200,200), range=((0, 1), (0, 1)))[0]
        hists.append(np.clip(counts, 0, np.percentile(counts, 99)))
    return np.mean(hists, axis=0)

densityA = _meanClippedDensity(arrTsneA)
densityP = _meanClippedDensity(arrTsneP)

In [None]:
def maskToPerimeter(mask):
    """Trace the outer boundary of a boolean pixel mask.

    Each True pixel (row r, column c) is treated as the unit square
    [r, r+1] x [c, c+1]; the squares are unioned and the exterior ring returned.
    Pixels on the outermost rows/columns are ignored.

    Args:
        mask: 2-D boolean array. It is copied; the caller's array is left untouched.

    Returns:
        (n, 2) float array of ring vertices (first column = row coordinate, second =
        column coordinate), or None when no pixel remains. If the pixels form several
        pieces that only touch at corners, the ring of the largest piece is returned.
    """
    mask = np.array(mask, copy=True)
    mask[0,:] = False
    mask[-1,:] = False
    mask[:,0] = False
    mask[:,-1] = False

    cells = np.argwhere(mask)
    if len(cells) == 0:
        # Nothing to outline; callers already test for None
        return None

    polys = [geom.Polygon([[_x + dx, _y + dy] for dx, dy in [[0,0],[0,1],[1,1],[1,0]]])
             for _x, _y in cells]
    # unary_union supersedes the deprecated cascaded_union
    outline = shapely.ops.unary_union(polys)
    if outline.geom_type == 'MultiPolygon':
        # Corner-touching squares union into separate polygons, which have no single exterior
        outline = max(outline.geoms, key=lambda part: part.area)

    a = outline.exterior.coords.xy
    xy = np.column_stack((np.array(a[0]), np.array(a[1])))

    return xy

In [None]:
import colorcet
import matplotlib.colors
import matplotlib.cm as cm

class customColormap(matplotlib.colors.LinearSegmentedColormap):
    """Variant of the 'cet_CET_L19' colormap that fades to white below 10%.

    Values >= 0.10 are rescaled to [0, 1] and passed to the base colormap. Values below
    0.10 are blended linearly between white (at 0) and the base colormap's first colour
    (at 0.10), with full opacity.
    """
    def __init__(self, *args, **kwargs):
        # NOTE(review): as before, the LinearSegmentedColormap initialiser is not invoked;
        # only the attributes set here exist on the instance.
        # plt.get_cmap is what the rest of the notebook uses (cm.get_cmap is gone in
        # recent Matplotlib releases).
        self.baseCM = plt.get_cmap('cet_CET_L19')
        self.N = self.baseCM.N
    def __call__(self, r, *args, **kwargs):
        """Map an array of values to RGBA colours.

        Args:
            r: scalar or array of values of any shape.
            *args, **kwargs: forwarded to the base colormap call.

        Returns:
            Array of shape ``r.shape + (4,)``; float64 when the base colormap returns
            float colours, uint8 otherwise.
        """
        lim = 0.10
        # Colour at the bottom of the base colormap; computed once per call, not per value
        first = self.baseCM(0.0, *args, **kwargs)
        asFloat = isinstance(first[0], float)
        white = 1.0 if asFloat else 255.0

        def mapColor(x):
            if x >= lim:
                return self.baseCM((x-lim)/(1.0 - lim), *args, **kwargs)
            z = (x / lim)
            a = first[0] * z + white * (1 - z)
            b = first[1] * z + white * (1 - z)
            c = first[2] * z + white * (1 - z)
            if asFloat:
                return (a, b, c, 1.0)
            return (int(a), int(b), int(c), 255)

        values = np.asarray(r)
        # Any dimensionality is handled by flattening, instead of printing '!!' and returning None
        colors = np.array([mapColor(v) for v in values.ravel()],
            dtype=np.float64 if asFloat else np.uint8)
        return colors.reshape(values.shape + (4,))

In [None]:
import regex as re

In [None]:
# Minimum averaged transition probability for an arrow to be drawn; also embedded in saved figure names
probThreshold = 0.05
# Colormap used for the density background (customColormap is the disabled alternative)
cmDensity = plt.get_cmap('cet_CET_L19') #customColormap()

def plot(ax, anterior):
    """Draw density, cluster outlines and transition links for one body region.

    Reads the notebook-level results (transition matrices, peaks, densities, label
    tables, watershed arrays) for the anterior or posterior embedding.

    Args:
        ax: matplotlib Axes to draw on (its frame is switched off).
        anterior: True for the anterior data set, False for the posterior one.
    """
    ax.set_axis_off()

    transitionsPlt = (transitionMtxA if anterior else transitionMtxP).copy().astype(np.float64)
    transitionsRate= (transitionMtxArates if anterior else transitionMtxPrates).copy().astype(np.float64)
    peaks = (peaksA if anterior else peaksP) * 200
    density = (densityA if anterior else densityP)
    clusterLabels = (clusterLabelsA if anterior else clusterLabelsP)
    arrWatershed = (arrWatershedA if anterior else arrWatershedP)

    # Plot density
    _dens = np.flip(density, axis=0)
    _dens = np.clip(_dens, 0, np.percentile(_dens, 99.9))
    _dens/= _dens.max()
    ax.imshow(_dens, extent=(0, 200, 0, 200), cmap=cmDensity)

    # Plot cluster boundaries
    watershedIDs = arrWatershed[:,:,0,1]
    # The outlines are drawn twice, as before; with semi-transparent colours this doubles
    # their opacity. NOTE(review): presumably intentional - confirm.
    for _ in range(0, 2):
        for clusterIDsKey in [x for x in clusterLabels if x != 'noisy']:
            clusterIDsKeyBase = re.search('.*(?=-[0-9]*)', clusterIDsKey).group(0)
            baseColor = COLORS_CLUSTERS[clusterIDsKeyBase]

            # Union of the watershed regions of all clusters in this entry
            mask = np.isin(watershedIDs, clusterLabels[clusterIDsKey])

            # Split masks into contiguous submasks (after a one-step morphological opening)
            maskLabeled = label(binary_dilation(binary_erosion(mask, iterations=1), iterations=1))
            for maskID in np.unique(maskLabeled):
                if maskID == 0:
                    # label() assigns 0 to the background; skip it by ID rather than by
                    # assuming the background is the largest component
                    continue
                mask = (maskLabeled == maskID)
                mask[:25,:] = False
                mask[175:,:] = False
                mask[:,:25] = False
                mask[:,175:] = False
                if not mask.any():
                    # Component lies entirely outside the displayed window
                    continue
                pts = maskToPerimeter(mask)
                if pts is not None:
                    ax.plot(pts[:,1], pts[:,0], linewidth=2, \
                        color=baseColor + '88')
                    ax.fill(pts[:,1], pts[:,0], linewidth=1, \
                        edgecolor=baseColor,
                        facecolor=baseColor + '11')

    # Sort arrows to plot by increasing thickness
    ijs = np.array(sorted(np.argwhere(transitionsPlt), key=lambda x: transitionsPlt[x[0],x[1]]))

    for (i, j) in ijs:
        if transitionsPlt[i,j] >= probThreshold and transitionsRate[i,j] >= 10:
            try:
                ax.annotate("",
                    xy=(peaks[i,1],peaks[i,0]), xycoords='data',
                    xytext=(peaks[j,1],peaks[j,0]), textcoords='data',
                    arrowprops=dict(
                        arrowstyle="-", color="#222222",
                        shrinkA=5, shrinkB=5,
                        patchA=None, patchB=None,
                        connectionstyle='arc3, rad=0.2',
                        linewidth=transitionsPlt[i, j] * 10
                        ))
            except Exception as e:
                # Best effort, as before: report the failing link and keep drawing the rest
                print(e)

    # Mark every cluster's median position
    ax.scatter(peaks[:,1], peaks[:,0], color='#ff3333', s=7)
    
# Three panels: anterior transitions, posterior transitions, and a hand-drawn legend
fig, ax = plt.subplots(1, 3, figsize=(16, 8))

ax[0].set_title('Anterior Transitions (>= {})'.format(probThreshold))
plot(ax[0], True)
# Show only the central 25..175 window of the 200-unit map
ax[0].set_xlim(25, 175)
ax[0].set_ylim(25, 175)

ax[1].set_title('Posterior Transitions (>= {})'.format(probThreshold))
plot(ax[1], False)
ax[1].set_xlim(25, 175)
ax[1].set_ylim(25, 175)

# Plot legend
ax[2].set_axis_off()
ax[2].set_xlim(25, 175)
ax[2].set_ylim(25, 175)

# One legend row per (probability, invert) pair; rows are stacked 20 units apart.
# NOTE: the loop variables (z, p, invert, x1, ...) stay defined in the notebook namespace.
for z, (p, invert) in enumerate([(0.1, True), (0.1, False), (0.5, False), (1.0, False)]):
    x1, y1, x2, y2 = 30, 80 + z * 20, 50, 80 + z * 20
    if invert:
        # Swap the end points so the link is drawn in the opposite direction
        a, b = x1, y1
        x1, y1 = x2, y2
        x2, y2 = a, b
    ax[2].scatter(x1, y1, color='#ff3333', s=20, zorder=10)
    ax[2].scatter(x2, y2, color='#ff3333', s=20, zorder=10)
    # Line width scales with the probability, matching the links drawn by plot()
    ax[2].annotate("",
        xy=(x1, y1), xycoords='data',
        xytext=(x2, y2), textcoords='data',
        arrowprops=dict(
            arrowstyle="-", color="#222222",
            shrinkA=3 * p, shrinkB=5 * p,
            patchA=None, patchB=None,
            connectionstyle='arc3, rad=0.2',
            linewidth=p * 10
            ))
    ax[2].text(min(x1, x2), y1 + 6, 'A', ha='center', fontsize=16)
    ax[2].text(max(x1, x2), y2 + 6, 'B', ha='center', fontsize=16)
    ax[2].text(max(x1, x2) + 20, y2, 'P({}) = {}'.format(
        'B → A' if invert else 'A → B', p), ha='left', va='center', fontsize=16)

fig.tight_layout()
# NOTE(review): absolute, user-specific output path; the file name encodes probThreshold and MERGE_MANUAL
fig.savefig('C:/Users/acorver/Desktop/paper-figures/transitions_arrows_v2_{}{}.pdf'.format(
    probThreshold, '_mergemanual' if MERGE_MANUAL else ''), bbox_inches = 'tight')

In [None]:
def plot(ax, anterior):
    """Draw density, cluster outlines and coloured transition arrows for one body region.

    This second style redefines ``plot``: gray density background, dashed gray outlines
    with colour-filled regions, and arrows whose width and face colour encode the
    transition probability.

    Args:
        ax: matplotlib Axes to draw on (its frame is switched off).
        anterior: True for the anterior data set, False for the posterior one.
    """
    ax.set_axis_off()

    transitionsPlt = (transitionMtxA if anterior else transitionMtxP).copy().astype(np.float64)
    transitionsRate= (transitionMtxArates if anterior else transitionMtxPrates).copy().astype(np.float64)
    peaks = (peaksA if anterior else peaksP) * 200
    density = (densityA if anterior else densityP)
    clusterLabels = (clusterLabelsA if anterior else clusterLabelsP)
    arrWatershed = (arrWatershedA if anterior else arrWatershedP)

    # Plot density
    ax.imshow(np.flip(density, axis=0), extent=(0, 200, 0, 200), cmap='gray_r')

    # Plot cluster boundaries
    watershedIDs = arrWatershed[:,:,0,1]
    # The outlines are drawn twice, as before; with semi-transparent fills this doubles
    # their opacity. NOTE(review): presumably intentional - confirm.
    for _ in range(2):
        for clusterIDsKey in [x for x in clusterLabels if x != 'noisy']:
            clusterIDsKeyBase = re.search('.*(?=-[0-9]*)', clusterIDsKey).group(0)
            baseColor = COLORS_CLUSTERS[clusterIDsKeyBase]

            # Union of the watershed regions of all clusters in this entry
            mask = np.isin(watershedIDs, clusterLabels[clusterIDsKey])

            # Split masks into contiguous submasks (after a one-step morphological opening)
            maskLabeled = label(binary_dilation(binary_erosion(mask, iterations=1), iterations=1))
            for maskID in np.unique(maskLabeled):
                if maskID == 0:
                    # label() assigns 0 to the background; skip it by ID rather than by
                    # assuming the background is the largest component
                    continue
                mask = (maskLabeled == maskID)
                mask[:25,:] = False
                mask[175:,:] = False
                mask[:,:25] = False
                mask[:,175:] = False
                if not mask.any():
                    # Component lies entirely outside the displayed window
                    continue
                pts = maskToPerimeter(mask)
                if pts is not None:
                    ax.plot(pts[:,1], pts[:,0], linewidth=2, \
                        color='#aaaaaa', linestyle='--', zorder=5)
                    ax.fill(pts[:,1], pts[:,0], linewidth=1, \
                        edgecolor=baseColor,
                        facecolor=baseColor + '22', zorder=5)

    # Arrow Colormap
    arrowCmap = plt.get_cmap('cet_CET_L18')

    # Sort arrows to plot by increasing thickness
    ijs = np.array(sorted(np.argwhere(transitionsPlt), key=lambda x: transitionsPlt[x[0],x[1]]))

    for (i, j) in ijs:
        # Use the shared threshold (the saved file name embeds probThreshold) instead of a
        # separate hard-coded 0.05; the comparison stays strict as before.
        if transitionsPlt[i,j] > probThreshold and transitionsRate[i,j] >= 10:
            ax.annotate("",
                xytext=(peaks[i,1],peaks[i,0]), xycoords='data',
                xy=(peaks[j,1],peaks[j,0]), textcoords='data',
                arrowprops=dict(fc=arrowCmap(int(transitionsPlt[i,j] * 255.0)),
                                mutation_scale=transitionsPlt[i,j] * 10,
                                ec='#222222',
                                linewidth=0.3,
                                width=transitionsPlt[i,j] * 10,
                                headwidth=5 + transitionsPlt[i,j] * 10,
                                connectionstyle="arc3, rad=-0.2",
                                shrinkA=15, shrinkB=15), zorder=10)
    
# Two panels (anterior, posterior) drawn with the second plot() style, plus a colorbar
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

ax[0].set_title('Anterior Transitions')
plot(ax[0], True)
# Show only the central 25..175 window of the 200-unit map
ax[0].set_xlim(25, 175)
ax[0].set_ylim(25, 175)

ax[1].set_title('Posterior Transitions')
plot(ax[1], False)
ax[1].set_xlim(25, 175)
ax[1].set_ylim(25, 175)

# Stand-alone colorbar axes placed at the right edge of the figure, same colormap as the arrows.
# NOTE(review): plot() colours arrows by probability * 255 without dividing by the maximum,
# so the 'Normalized to Maximum' wording of the label should be confirmed.
cbar_ax = fig.add_axes([0.99, 0.3, 0.02, 0.4])
cb = mpl.colorbar.ColorbarBase(
    cbar_ax, cmap=plt.get_cmap('cet_CET_L18'),
    orientation='vertical', ticks=[0, 1], 
    label='Transition Probability Normalized to Maximum')
cb.ax.set_yticklabels(['0', '100%'])

fig.tight_layout()
# NOTE(review): absolute, user-specific output path; the file name encodes probThreshold and MERGE_MANUAL
fig.savefig('C:/Users/acorver/Desktop/paper-figures/transitions_arrows_style2_{}{}z2.pdf'.format(
    probThreshold, '_mergemanual' if MERGE_MANUAL else ''), bbox_inches = 'tight')