In [None]:
#data format library
import h5py

#numpy
import numpy as np
import pandas as pd
import numpy.ma as ma
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
# Figure text settings: no LaTeX rendering, and SVG text written as text
# elements ('none') rather than converted to paths, so exported figures stay editable.
new_rc_params = {
    'text.usetex': False,
    "svg.fonttype": 'none',
}
mpl.rcParams.update(new_rc_params)
mpl.rcParams["font.family"] = "Times New Roman"
# %matplotlib notebook
import sys
sys.path.append('/Users/gautam.sridhar/Documents/Repos/Markov_Fish/utils/')
import matplotlib.colors as pltcolors
import os
import copy
import clustering_methods as cl
import operator_calculations as op_calc
import delay_embedding as embed
import stats
import time

# Seed NumPy's global RNG (it is reseeded again before the bootstrap seeds are drawn).
np.random.seed(42)

import importlib
# Re-import operator_calculations so edits to the utils module are picked up without a kernel restart.
importlib.reload(op_calc)

In [None]:
# Load per-recording kinematics. Arrays are (n_fish, max_n_bouts, dim of variable);
# zeros presumably mark padding beyond each recording (they are masked in the next cell).
path_to_filtered_data = '/Users/gautam.sridhar/Documents/Repos/ZebraBouts/Datasets/Full_Data/'
# The context manager closes the HDF5 file even if one of the reads below fails
# (the former open/close pair leaked the handle on error). ma.array copies each
# dataset into memory, so the arrays stay valid after the file is closed.
with h5py.File(path_to_filtered_data+'filtered_jmpool_ex5_kin.h5','r') as f:
    lengths = np.array(f['MetaData/lengths_data'],dtype=int)
    bouttypes= ma.array(f['bout_types'], dtype=int)

    bouttypes_allcond= ma.array(f['bout_types'])
    stims_allcond = ma.array(f['stims'])
    ecs_allcond = ma.array(f['eye_convergence'])
    time_Bout_allcond = ma.array(f['times_bouts']) # raw bout times

    X_head_allcond = ma.array(f['head_pos'])
    phi_smooth_allcond = ma.array(f['orientation_smooth'])
    speeds_head_allcond = ma.array(f['speed_head'])

In [None]:
# Mask sentinel / padding values in place so masked reductions ignore them.
# NOTE(review): sentinels assumed to be 0 (times, positions, orientation), 15 (bout types)
# and 100 (eye convergence) — confirm against the code that wrote the HDF5 file.
time_Bout_allcond[time_Bout_allcond == 0] = ma.masked
X_head_allcond[X_head_allcond == 0] = ma.masked
phi_smooth_allcond[phi_smooth_allcond == 0] = ma.masked
# Speeds reuse the orientation mask set on the previous line (boolean index requires equal shapes).
speeds_head_allcond[phi_smooth_allcond.mask] = ma.masked
bouttypes_allcond[bouttypes_allcond == 15] = ma.masked
ecs_allcond[ecs_allcond == 100] = ma.masked

In [None]:
# Experimental condition names; condition_recs[k] = [t0, tf) is the half-open range of
# recording indices belonging to condition_labels[k]. The 16 ranges tile recordings 0-524.
condition_labels = ['Light (5x5cm)','Light (1x5cm)','Looming(5x5cm)','ChasingDot coarsespeeds(5x5cm)','ChasingDot finespeeds(5x5cm)','Dark_Transitions(5x5cm)',
                    'Phototaxis','Optomotor Response (1x5cm)','Optokinetic Response (5x5cm)','Dark (5x5cm)','3 min Light<->Dark(5x5cm)',
                    'Prey Capture Param. (2.5x2.5cm)','Prey Capture Param. RW. (2.5x2.5cm)',
                    'Prey Capture Rot.(2.5x2.5cm)','Prey capture Rot. RW. (2.5x2.5cm)','Light RW. (2.5x2.5cm)']

condition_recs = np.array([[515,525],[160,172],[87,148],[43,60],[22,43],[60,87],
                           [202,232],[148,160],[172,202],[505,515],[0,22],
                           [232,301],[347,445],[301,316],[316,347],
                           [445,505]])

In [None]:
# Per-recording lookup table: column 0 = recording index, column 1 = condition label.
conditions = np.zeros((np.max(condition_recs), 2), dtype='object')
for k, (t0, tf) in enumerate(condition_recs):
    conditions[t0:tf, 0] = np.arange(t0, tf)
    # a scalar string broadcasts to every row of the range
    conditions[t0:tf, 1] = condition_labels[k]

In [None]:
# Indices of recordings excluded from the analysis.
recs_remove = np.load(file='/Users/gautam.sridhar/Documents/Repos/ZebraBouts/Datasets/Full_Data/recs_remove.npy')

In [None]:
# Also exclude recordings 22-59 (the two ChasingDot conditions in condition_recs).
# union1d is idempotent: re-running the cell no longer appends the range again, and
# duplicates are removed. np.delete ignores index order, so the sorting is harmless.
recs_remove = np.union1d(recs_remove, np.arange(22,60)).astype(int)

In [None]:
# Show the recording indices that are dropped from the kinematic arrays further down.
print(recs_remove)

In [None]:
## Load symbolic sequences

path_to_filtered_data = '/Users/gautam.sridhar/Documents/Repos/Markov_Fish/Datasets/JM_Data/'
# Open read-only explicitly: h5py < 3.0 defaulted to mode 'a' (read/write, creating the
# file if missing). The context manager closes the file even if a read fails.
with h5py.File(path_to_filtered_data + 'pool_ex8_PCs/kmeans_labels_K5_N1200_s8684.h5', 'r') as f:
    lengths_all = np.array(f['MetaData/lengths_data'], dtype=int)
    labels_fish_allrec = ma.array(f['labels_fish'],dtype=int)
    state_trajs = ma.array(f['state_trajs'])

# lengths_all = np.load('/Users/gautam.sridhar/Documents/Repos/ZebraBouts/Datasets/Full_Data/lengths_ex2_recordings.npy')
# lengths_all = lengths

In [None]:
# Mask the padding symbol in the microstate sequences. No recordings are dropped here:
# labels_fish and lengths_rem are aliases (not copies) of the loaded arrays.
to_mask = 1300
maxL = lengths_all.max()

labels_fish_allrec[labels_fish_allrec == to_mask] = ma.masked
labels_fish = labels_fish_allrec
lengths_rem = lengths_all

In [None]:
## Select dataset and bootstrap settings
dt = 1
div = 463
n_modes = 50                       # number of eigen-modes tracked per bootstrap sample
delay_range = np.arange(1, 20)     # lags 1..19

# one RNG seed per bootstrap replicate, reproducible from the fixed global seed
np.random.seed(42)
seeds = np.random.randint(0, 10000, 100)

labels_all = ma.concatenate(labels_fish, axis=0)   # all recordings flattened into one sequence
print(labels_fish.shape)

In [None]:
def sampler(labels_fish, conditions ,lengths_all, window_size, to_mask, s, condition_labels=None):
    """Draw one random contiguous window of ``window_size`` symbols from each condition.

    For every condition, the recordings whose label (``conditions[:, 1]``) matches are
    concatenated. The last two symbols of each recording are masked first, so that
    recording boundaries are marked by masked entries. A window start is then drawn
    uniformly from every valid position.

    Parameters
    ----------
    labels_fish : masked array, (n_recordings, max_len)
        Symbol sequence of each recording; not modified (fancy indexing copies rows).
    conditions : object array, (n_recordings, 2)
        Column 1 holds each recording's condition label.
    lengths_all : int array
        Valid length of each row of ``labels_fish``.
    window_size : int
        Number of symbols drawn per condition.
    to_mask : int
        Padding sentinel; output entries equal to it are masked.
    s : int
        Seed passed to ``np.random.seed`` (reseeds NumPy's global RNG).
    condition_labels : list of str, optional
        Conditions to sample. Defaults to the 14 conditions below (the two ChasingDot
        conditions are excluded).

    Returns
    -------
    list of masked int arrays, one per condition, each of length ``window_size + 4``:
    two masked entries, the sampled window, two masked entries.

    Raises
    ------
    ValueError
        If a condition has no recordings or fewer than ``window_size`` symbols.
    """
    np.random.seed(s)
    if condition_labels is None:
        condition_labels = ['Light (5x5cm)','Light (1x5cm)','Looming(5x5cm)','Dark_Transitions(5x5cm)',
                        'Phototaxis','Optomotor Response (1x5cm)','Optokinetic Response (5x5cm)','Dark (5x5cm)','3 min Light<->Dark(5x5cm)',
                        'Prey Capture Param. (2.5x2.5cm)','Prey Capture Param. RW. (2.5x2.5cm)',
                        'Prey Capture Rot.(2.5x2.5cm)','Prey capture Rot. RW. (2.5x2.5cm)','Light RW. (2.5x2.5cm)']
    sampled_labels = []
    for cond in condition_labels:
        cond_recs = np.where(conditions[:,1] == cond)[0]
        if len(cond_recs) == 0:
            raise ValueError(f"no recordings found for condition {cond!r}")
        # fancy indexing returns a copy, so the masking below never touches labels_fish
        labels_cond = labels_fish[cond_recs]
        rec_list = []
        for i, l in enumerate(lengths_all[cond_recs]):
            # clamp at 0 so recordings shorter than 2 symbols do not produce a negative start
            labels_cond[i, max(l-2, 0):l] = ma.masked
            rec_list.append(labels_cond[i,:l])

        labels_ = ma.hstack(rec_list)
        # valid starts are 0 .. len - window_size inclusive; randint's upper bound is
        # exclusive, hence the +1 (the final window was previously never drawn)
        n_starts = len(labels_) - window_size + 1
        if n_starts < 1:
            raise ValueError(
                f"condition {cond!r} has {len(labels_)} symbols, fewer than window_size={window_size}")
        start_pos = np.random.randint(0, n_starts)

        # two padding entries on each side separate consecutive conditions once masked
        samples_ = to_mask*ma.ones((window_size+4,), dtype=int)
        samples_[2:-2] = labels_[start_pos:start_pos+window_size]
        samples_[samples_==to_mask] = ma.masked
        sampled_labels.append(samples_)
    return sampled_labels

In [None]:
## Calculate implied timescales by sampling similar number of bouts from each condition

ts_traj_delay = ma.zeros((seeds.shape[0],len(delay_range),n_modes))
eigvals_shuffle = ma.zeros((seeds.shape[0],n_modes-1))

window_size = 7500

for i,s in enumerate(seeds):
    print(s)
    labels_bootstrap = sampler(labels_fish, conditions,lengths_rem, window_size, to_mask, s)
    labels_ = ma.hstack(labels_bootstrap)
    nstates = ma.max(labels_) + 1
    segments = op_calc.segment_maskedArray(labels_)

    # The unmasked segments have different lengths. np.asarray on such a ragged list
    # raises ValueError since NumPy 1.24, so build the 1-D object array explicitly
    # (what older NumPy produced implicitly).
    dtrajs = np.empty(len(segments), dtype=object)
    for n_seg, (t0, tf) in enumerate(segments):
        dtrajs[n_seg] = labels_[t0:tf]

    # implied timescales only when enough unmasked symbols were sampled
    if ma.count(labels_)>20:
        for kd,delay in enumerate(delay_range):
            ts_traj_delay[i,kd,:] = op_calc.implied_tscale_shuffle(dtrajs,nstates,delay,dt,n_modes,reversible=True)

    # Noise floor: spectrum of the reversible lag-1 transition matrix of the shuffled trajectories
    P_shuffle = op_calc.transition_matrix_shuffle(dtrajs,1)
    R_shuffle = op_calc.get_reversible_transition_matrix(P_shuffle)
    eigvals,eigvecs = op_calc.sorted_spectrum(R_shuffle,k=n_modes)
    sorted_indices = np.argsort(eigvals.real)[::-1]
    eigvals = eigvals[sorted_indices][1:].real   # drop the leading eigenvalue
    eigvals[np.abs(eigvals-1)<1e-12] = np.nan
    eigvals[eigvals<1e-12] = np.nan
    eigvals_shuffle[i,:] = eigvals

In [None]:
## Use if sampling similar number of bouts from each condition
# Aliases (same objects), so masking through them also masks the originals.
ts_traj_delay_total = ts_traj_delay
eigvals_shuffle_total = eigvals_shuffle

# Entries still equal to 0 were never written (e.g. seeds skipped by the count check).
ts_traj_delay_total[ts_traj_delay_total == 0] = ma.masked
eigvals_shuffle_total[eigvals_shuffle_total == 0] = ma.masked

print(ts_traj_delay_total.shape)
print(eigvals_shuffle_total.shape)

In [None]:
# Fig 2a: bootstrap mean and CI of the implied timescale of every mode at delay tau,
# against the noise floor from the shuffled sequences.
fig, ax = plt.subplots(1,1,figsize=(10,10))

tau = 3

# delay_range starts at 1, so index tau-1 selects delay tau
ts_traj_delay1 = ts_traj_delay_total[:,tau-1,:]
mean,cil,ciu=stats.bootstrap(ts_traj_delay1[:,:],median=False,n_times=1000)
ax.scatter(np.arange(n_modes),mean,c='k', s =50)
ax.errorbar(np.arange(n_modes), mean,[mean-cil, ciu - mean], fmt='.', capsize=7, elinewidth=3, color = 'k')

# Noise floor. tscales_shuffle used to be defined only in the Fig S2a cell below, so this
# cell raised NameError (or used stale values) on a fresh kernel; compute it here.
tscales_shuffle = np.array([-d/np.log(eigvals_shuffle_total[:,0]) for d in delay_range])
mean,cil,ciu=stats.bootstrap((tscales_shuffle.T),n_times=1000)
ax.axhline(mean[tau-1],c='k',alpha=.5,ls='--', label='Noise Floor')
ax.axhspan(cil[tau-1],ciu[tau-1],color='k',alpha=.3)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.xticks(fontsize=30)
plt.yticks(fontsize=30)

In [None]:
# Fig S2a - implied timescales as a function of the delay tau, one curve per mode
# one colour per mode (was hard-coded to 50, which breaks for n_modes > 50)
colors_ = plt.cm.Reds_r(np.linspace(0,1,n_modes))

fig, ax = plt.subplots(1,1,figsize=(10,10))
for mode in range(n_modes):
    mean,cil,ciu=stats.bootstrap(ts_traj_delay_total[:,:,mode].squeeze(),median=False,n_times=1000)
    ax.plot(delay_range*dt,mean,c=colors_[mode], marker='*')
    ax.errorbar(delay_range*dt, mean,[mean-cil, ciu - mean], marker='o', capsize=4, color = colors_[mode])

# Noise floor: timescale implied by the leading eigenvalue of the shuffled chains
tscales_shuffle = np.array([-d/np.log(eigvals_shuffle_total[:,0]) for d in delay_range])
mean,cil,ciu=stats.bootstrap((tscales_shuffle.T),n_times=1000)

ax.plot(delay_range*dt,mean,c='k',alpha=.5,ls='--', label='Noise Floor')
ax.fill_between(delay_range*dt,cil,ciu,color='k',alpha=.3)
ax.legend(fontsize = 30)
ax.set_xlabel(r'$\tau$ (States)',fontsize=30)
ax.set_ylabel(r'$t^{imp} (\tau)$',fontsize=40)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_xscale('log')
ax.set_ylim(0,15)
plt.xticks(fontsize=30)
plt.yticks(fontsize=30)
plt.show()

In [None]:
# Ensemble transition matrix over the 1200 microstates, estimated elsewhere.
P_ensemble = np.load(file='/Users/gautam.sridhar/Documents/Repos/ZebraBouts/Datasets/Full_Data/P_ensemble_ex8_N1200_s8684.npy')

In [None]:
# Fig 2a inset: the ensemble transition matrix, colour scale clipped at 0.003.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))

ax.imshow(P_ensemble, cmap='magma', vmax=0.003)
ax.axis('off')

In [None]:
from scipy.sparse import diags,identity,coo_matrix, csr_matrix
# Store the transition matrix in CSR sparse format; later cells call P_ensemble.todense(),
# which requires a scipy sparse matrix.
P_ensemble = csr_matrix(P_ensemble)

In [None]:
import msmtools.estimation as msm_estimation
delay = 3
dt = 1
print(delay)
# lcs_ensemble,P_ensemble = op_calc.transition_matrix(labels_all,delay,return_connected=True)
lcs_ensemble = msm_estimation.largest_connected_set(P_ensemble)
inv_measure = op_calc.stationary_distribution(P_ensemble)
final_labels = op_calc.get_connected_labels(labels_all,lcs_ensemble)
R = op_calc.get_reversible_transition_matrix(P_ensemble)
eigvals,eigvecs = op_calc.sorted_spectrum(R,k=10,seed=123)
# Order modes by decreasing real part. Apply the SAME permutation to the eigenvector
# columns so that eigfunctions[:, m] stays paired with mode m (previously only the
# eigenvalues were reordered).
sorted_indices = np.argsort(eigvals.real)[::-1]
eigvecs = eigvecs[:, sorted_indices]
# Drop the leading eigenvalue (presumably the stationary mode, eigenvalue 1);
# eigvecs keeps that mode in column 0.
eigvals = eigvals[sorted_indices][1:].real
eigvals[np.abs(eigvals-1)<1e-12] = np.nan
eigvals[eigvals<1e-12] = np.nan
t_imp =  -(delay*dt)/np.log(np.abs(eigvals))
eigfunctions = eigvecs.real/np.linalg.norm(eigvecs.real,axis=0)   # unit-norm columns
# eigenfunction values along the bout sequence; masked labels stay masked
eigfunctions_traj = ma.array(eigfunctions)[final_labels,:]
eigfunctions_traj[final_labels.mask] = ma.masked

In [None]:
# Load eigenfunctions  
# eigfunctions = np.load('/Users/gautam.sridhar/Documents/Repos/ZebraBouts/Results/pool_ex7_PCs/eigfs_n1200.npy')
# phi1 = eigfunctions[:,1]
# phi2 = eigfunctions[:,2]
# phi3 = eigfunctions[:,3]

In [None]:
# Rank-based "distortion" of every non-trivial eigenfunction: states below the coherent
# split index are mapped linearly onto [-1, 0] by rank, the rest onto [0, 1].
split_locs = []
distorted_eigfs = np.zeros((eigfunctions.shape[0], eigfunctions.shape[1]-1))
for i in range(1,eigfunctions.shape[1]):
    phi = eigfunctions[:,i]
    _,_,_,split_idx,_ = op_calc.optimal_partition(phi,inv_measure,P_ensemble,return_rho=True)

    sort_range = np.sort(phi)
    neg_range = np.linspace(-1,0, len(sort_range[0:split_idx]))
    pos_range = np.linspace(0,1,len(sort_range[split_idx:]))
    distort_r = np.hstack([neg_range,pos_range])

    # pos[j] is the state holding the j-th smallest value of phi. argsort gives a distinct
    # state for every rank; the former np.where(phi == a)[0][0] lookup was O(n^2) and
    # sent tied values to the same state, leaving the other tied states at 0.
    pos = np.argsort(phi, kind='stable')
    distort = np.zeros(phi.shape)
    distort[pos] = distort_r

    distorted_eigfs[:,i-1] = distort
    split_locs.append(split_idx)

In [None]:
# First two distorted eigenfunctions (one value per microstate) used as map coordinates.
phi_1, phi_2 = distorted_eigfs[:, 0], distorted_eigfs[:, 1]

In [None]:
# Look up the distorted eigenfunction values of each bout's microstate. Masked labels
# index with their underlying data, so their rows are re-masked on the next line.
distorted_trajs = ma.array(distorted_eigfs)[final_labels,:]
distorted_trajs[final_labels.mask] = ma.masked

# Back to one row per recording: (n_recordings, max_len, n_distorted_eigfs).
distorted_fish = distorted_trajs.reshape(labels_fish.shape[0], labels_fish.shape[1], -1)

### Kinematics calculation (Fig. 2b, c)

In [None]:
# Kinematics calculation
# Remove the excluded recordings (recs_remove) from the orientation array.
# NOTE(review): np.delete appears to drop the mask (zeros are re-masked on the next line),
# and the cell is not idempotent — re-running it deletes a further set of rows.

phi_smooth_allcond = np.delete(phi_smooth_allcond, recs_remove, axis=0)
phi_smooth_allcond[phi_smooth_allcond == 0] = ma.masked

In [None]:
# Heading change (degrees) between the first orientation sample of consecutive bouts.
# Computed per recording, giving psi with shape (n_fish, n_bouts) as the following cells
# index it. The former version differenced the fish-concatenated array, producing a 1-D
# psi that also took differences across recording boundaries.
phis_all = ma.concatenate(phi_smooth_allcond,axis=0)
phi_first = phi_smooth_allcond[:, :, 0]
delphi = ma.abs(phi_first[:, 1:] - phi_first[:, :-1])*(180/np.pi)
psi = ma.masked_all(phi_first.shape)
psi[:, :-1] = delphi   # the last bout of each recording has no successor and stays masked

In [None]:
# Moving average of |psi| over K consecutive bouts of each recording, then its mean per
# microstate (phi_labels[kl] for microstate kl).
K=5
# accept psi either as (n_fish, n_bouts) or flattened over fish (one row per recording)
psi_fish = psi if psi.ndim == 2 else psi.reshape(phi_smooth_allcond.shape[0], -1)
n_bouts = psi_fish.shape[1]
meanKphi_fish = [ma.abs(psi_fish[:,k:k+K]).mean(axis=1) for k in range(n_bouts-K)]
meanKphi_fish = ma.vstack(meanKphi_fish).T

# Window k is stored at its centre bout k + K//2 (generalises the former K=5-only [:, 2:-3]).
# Bouts without a full window keep the -1 sentinel (|psi| >= 0 never produces it) and are masked.
meanKphi = -1*ma.ones((psi_fish.shape[0], n_bouts))
meanKphi[:, K//2:K//2 + meanKphi_fish.shape[1]] = meanKphi_fish
print(meanKphi.shape)

meanKphi[meanKphi==-1] = ma.masked
meanKphi_all = ma.hstack(meanKphi)
phi_labels = np.asarray([ma.mean(meanKphi_all[labels_all==kl]) for kl in range(1200)])

In [None]:
# Alternative phi_labels: per-microstate mean of `phis_embed`.
# NOTE(review): `phis_embed` is never defined in this notebook, so this cell raised NameError
# on a fresh kernel. Override phi_labels only when it exists; otherwise keep the value
# computed from meanKphi_all.
if 'phis_embed' in globals():
    phi_labels = np.asarray([ma.mean(phis_embed[labels_all==kl]) for kl in range(1200)])
else:
    print("phis_embed is not defined; keeping phi_labels computed from meanKphi_all")

In [None]:
## Kinetic map: microstates at their (phi_1, phi_2) coordinates, coloured by phi_labels
fig, ax = plt.subplots(1, 1, figsize=(12, 10))

im = ax.scatter(phi_1, phi_2, c=phi_labels, cmap='viridis', vmin=0.5, vmax=1.3, s=200, alpha=1.)
fig.colorbar(im)

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
# plt.savefig('/Users/gautam.sridhar/Documents/ZENITH/Figures/Fig2/phi1_phi2_speed.pdf')


In [None]:
# Drop the excluded recordings from the speed array and reuse the orientation mask.
# NOTE(review): not idempotent (re-running deletes more rows); the boolean-mask index
# requires speeds and orientations to have identical shapes.
speeds_head_allcond = np.delete(speeds_head_allcond, recs_remove,axis=0)
speeds_head_allcond[phi_smooth_allcond.mask] = ma.masked

In [None]:
# Moving average of head speed over K consecutive bouts (each bout's samples averaged),
# then its mean per microstate (speed_labels[kl] for microstate kl).
K=5
n_bouts = speeds_head_allcond.shape[1]
meanKspeed_fish = [speeds_head_allcond[:,k:k+K,:].mean(axis=1).mean(axis=1) for k in range(n_bouts-K)]
meanKspeed_fish = ma.vstack(meanKspeed_fish).T

# Window k is stored at its centre bout k + K//2 (generalises the former K=5-only [:, 2:-3]);
# bouts without a full window keep the -1 sentinel and are masked below.
meanKspeed = -1*ma.ones((speeds_head_allcond.shape[0], n_bouts))
meanKspeed[:, K//2:K//2 + meanKspeed_fish.shape[1]] = meanKspeed_fish
# Print the array built in this cell: meanKspeed_all is only defined a few lines below,
# so the former print raised NameError on a fresh kernel (or showed a stale shape).
print(meanKspeed.shape)

meanKspeed[meanKspeed==-1] = ma.masked
meanKspeed_all = ma.hstack(meanKspeed)
speed_labels = np.asarray([ma.mean(meanKspeed_all[labels_all==kl]) for kl in range(1200)])

In [None]:
# Alternative speed_labels: per-microstate mean of `speeds_embed`.
# NOTE(review): `speeds_embed` is never defined in this notebook, so this cell raised
# NameError on a fresh kernel. Override speed_labels only when it exists; otherwise keep
# the value computed from meanKspeed_all.
if 'speeds_embed' in globals():
    speed_labels = np.asarray([ma.mean(speeds_embed[labels_all==kl]) for kl in range(1200)])
else:
    print("speeds_embed is not defined; keeping speed_labels computed from meanKspeed_all")

In [None]:
## Kinetic map: microstates at their (phi_1, phi_2) coordinates, coloured by speed_labels
fig, ax = plt.subplots(1, 1, figsize=(12, 10))

im = ax.scatter(phi_1, phi_2, c=speed_labels, cmap='viridis', vmin=0.5, vmax=1.3, s=200, alpha=1.)
fig.colorbar(im)

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
# plt.savefig('/Users/gautam.sridhar/Documents/ZENITH/Figures/Fig2/phi1_phi2_speed.pdf')

### Kinematics (Fig. 2d)

In [None]:
# Moving average of psi over K consecutive bouts of each recording (used by Fig 2d).
K=5
# accept psi either as (n_fish, n_bouts) or flattened over fish (one row per recording)
psi_fish = psi if psi.ndim == 2 else psi.reshape(phi_smooth_allcond.shape[0], -1)
n_bouts = psi_fish.shape[1]
meanKphi_fish = [psi_fish[:,k:k+K].mean(axis=1) for k in range(n_bouts-K)]
meanKphi_fish = ma.vstack(meanKphi_fish).T

# Window k is stored at its centre bout k + K//2 (generalises the former K=5-only [:, 2:-3]);
# bouts without a full window keep the -1 sentinel and are masked below.
meanKphi = -1*ma.ones((psi_fish.shape[0], n_bouts))
meanKphi[:, K//2:K//2 + meanKphi_fish.shape[1]] = meanKphi_fish
print(meanKphi.shape)

meanKphi[meanKphi==-1] = ma.masked
meanKphi_all = ma.hstack(meanKphi)

In [None]:
## Fig 2d: K-bout mean heading change against distorted eigenfunction column 2
## (eigfunctions[:, 3]) for every bout
fig, ax = plt.subplots(1,1,figsize=(12,10))

pmin = 0.0

distorted_ = ma.concatenate(distorted_fish[:,:,2], axis=0)

# `mean_Kphi_all` was a typo for `meanKphi_all` (NameError on a fresh kernel).
# NOTE(review): the two arrays are compressed independently, so x/y pairs only line up if
# both have identical masks; the [:-2] trim below suggests they do not — confirm.
im = ax.scatter(distorted_.compressed()[0::50],meanKphi_all.compressed()[0::50],c='grey', alpha=0.1,s=5)
img = stats.density_plot(distorted_.compressed()[:-2],meanKphi_all.compressed()[:],[-1.1,1.1],[0,0.5],90,90,smooth=True, border=3)

# 96 grid points per axis, presumably 90 bins + 2 * border(3) returned by density_plot
X,Y = np.meshgrid(np.linspace(-1.1,1.1,96), np.linspace(0,0.5,96))
ax.contour(X,Y, img, cmap='magma', linewidths=4)

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# ax.axis('off')

# plt.savefig('/Users/gautam.sridhar/Documents/ZENITH/Figures/Fig2/phi1_phi2_speed.pdf')


### Coherence-based split (Fig. 2e)

In [None]:
# Coherent-set split along column 1 of the eigenfunctions computed above (column 0 is the
# leading mode). `phi1` was only defined in the commented-out loading cell, so this cell
# raised NameError on a fresh kernel.
phi1 = eigfunctions[:, 1]
c_range_phi1,rho_sets,measures,split_idx_phi1,coh_labels_phi1 = op_calc.optimal_partition(phi1,inv_measure,P_ensemble,return_rho=True)
kmeans_labels = coh_labels_phi1

In [None]:
# Microstate order by coherent-set label, highest label first.
sort_clus_idx = np.flip(np.argsort(coh_labels_phi1))

In [None]:
# Transition matrix with rows and columns permuted so microstates of the same coherent
# set are contiguous (todense already returns a fresh dense matrix).
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
D2 = P_ensemble.todense()[np.ix_(sort_clus_idx, sort_clus_idx)]

ax.imshow(D2, cmap='magma', vmax=0.003)
ax.axis('off')

In [None]:
# Replace each bout's microstate label by its coherent-set label (kmeans_labels is indexed
# by microstate); bouts with a masked label stay masked.
cluster_traj_all = ma.copy(final_labels)
cluster_traj_all[~final_labels.mask] = ma.array(kmeans_labels)[final_labels[~final_labels.mask]]
cluster_traj_all[final_labels.mask] = ma.masked

# Back to one row per recording, matching labels_fish.
cluster_fish = cluster_traj_all.reshape(labels_fish.shape[0],labels_fish.shape[1])
cluster_fish_mask = cluster_fish.mask

In [None]:
# Use every remaining recording (presumably the 463 kept after removal), truncated to the
# longest of them, and flatten into one coherent-set label sequence.
cond_recs = np.arange(463)
maxL = lengths_rem[cond_recs].max()
cluster_fish_condition = cluster_fish[cond_recs, :maxL]
cluster_traj = ma.concatenate(cluster_fish_condition, axis=0)
print(cluster_fish_condition.shape)

In [None]:
# Short alias used by the heading-change loop below. It is the same object, not a copy,
# so masking zero samples here also masks them in phi_smooth_allcond.
phi_smooth = phi_smooth_allcond
phi_smooth[phi_smooth_allcond == 0] = ma.masked

In [None]:
# Heading change per bout (degrees, last minus first valid orientation sample), grouped by
# coherent-set label: phis_fish[ks] holds one array per recording containing label ks.
# start_rec/end_rec were never defined in this notebook (NameError on a fresh kernel);
# derive them from cond_recs, which selected the rows of cluster_fish_condition.
start_rec, end_rec = int(cond_recs[0]), int(cond_recs[-1]) + 1
cluster_ids = np.unique(cluster_traj.compressed())   # loop-invariant, hoisted
# one list per label: at least 8 as before, more if the labels require it
n_label_lists = max(8, int(cluster_ids.max()) + 1) if cluster_ids.size else 8
phis_fish = [[] for _ in range(n_label_lists)]
for cf in range(start_rec,end_rec):
    for ks in cluster_ids:
        sel = cluster_fish_condition[cf - start_rec] == ks
        # [start, end) index pairs of the consecutive runs of label ks
        segments = np.where(np.abs(np.diff(np.concatenate([[False], sel, [False]]))))[0].reshape(-1, 2)
        if len(segments) == 0:
            continue
        sorted_indices = np.argsort(np.hstack(np.diff(segments,axis=1)))[::-1]   # longest runs first
        avg_orientation = []
        for idx in sorted_indices:
            t0,tf = segments[idx]
            for j in range(t0,tf):
                # valid orientation samples of bout j; compressed() also works when the mask
                # is nomask (indexing with ~mask then added an axis)
                angle = phi_smooth[cf,j].compressed()
                if angle.size == 0:
                    continue
                # last minus first VALID sample (was angle[0], which may itself be masked)
                orient_diff = np.abs((angle[-1] - angle[0])*(180/np.pi))
                avg_orientation.append(orient_diff)
        if avg_orientation:
            phis_fish[ks].append(ma.hstack(avg_orientation))

In [None]:
# Bootstrap the per-label heading-change distributions collected in phis_fish.
# NOTE(review): stats.dist_bootstrap is project code; the positional args (0, 150, 2, 10)
# presumably set a 0-150 degree value range and binning/bootstrap parameters — confirm.
y_ebs_phis, x_all_phis = stats.dist_bootstrap(phis_fish, 0, 150, 2, 10)

In [None]:
# Fig 2e: per coherent-set label, 1 - y_ebs_phis as a function of heading change per bout
# (dashed line: column 0; band between columns 1 and 2, presumably the bootstrap CI).
# NOTE(review): `st_colors` is never defined in this notebook (NameError on a fresh kernel);
# fall back to a qualitative palette — replace with the intended figure colours.
if 'st_colors' not in globals():
    st_colors = plt.cm.tab10(np.arange(10))

fig,ax = plt.subplots(1,1, figsize=(10,10))

for ms in np.unique(cluster_traj_all.compressed()):
    ax.plot(x_all_phis[ms], 1-y_ebs_phis[ms][:,0], color='k', ls = '--', alpha=1.)
    ax.fill_between(x_all_phis[ms], 1-y_ebs_phis[ms][:,1], 1-y_ebs_phis[ms][:,2], alpha=0.7, color=st_colors[ms])

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# ax.set_yscale('log')
# ax.set_xlabel(r'$\Delta$ Orientation per bout ($\circ$)',fontsize=35)
# ax.set_ylabel('CDF',fontsize= 40)
plt.xticks(fontsize=30)
plt.yticks(fontsize=30)
ax.set_xlim(-1,150)
ax.set_ylim(1e-4,1)

# plt.savefig('/Users/gautam.sridhar/Documents/ZENITH/Figures/Fig2/WC_heading_new.pdf')

### Lab-space trajectories (figure insets)

In [None]:
# Drop the same excluded recordings as for orientations and speeds. np.delete requires the
# indices to remove; the call without `recs_remove` raised TypeError.
X_head_allcond = np.delete(X_head_allcond, recs_remove, axis=0)
X_head_allcond[X_head_allcond == 0] = ma.masked

In [None]:
# Short alias (same object, not a copy) used by the trajectory inset cells.
X_head = X_head_allcond

In [None]:
# Inset: lab-space head trajectory of recording `rec`. Every position sample of a bout is
# coloured by that bout's value of distorted eigenfunction column 0 (diverging map centred at 0).
eigfish = distorted_fish[:, :, 0]
rec = 442

fig, ax = plt.subplots(1, 1, figsize=(10, 10))

# plot bouts [st, en); en is the number of unmasked bouts in this recording
st = 0
en = len(eigfish[rec].compressed())
print(en)

X_toplot = X_head[rec, st:en, :, 0]
Y_toplot = X_head[rec, st:en, :, 1]
print(X_toplot.shape)

# one colour value per bout, repeated for each of the X_head.shape[2] samples in that bout
colors_ = ma.repeat(eigfish[rec, st:en][:, np.newaxis], X_head.shape[2], axis=1)
divnorm = pltcolors.TwoSlopeNorm(vmin=-1, vcenter=0, vmax=1.)

im = ax.scatter(ma.hstack(X_toplot), ma.hstack(Y_toplot), c=colors_, norm=divnorm, cmap='PuOr_r',
                ec='k', linewidths=0.05, s=20., alpha=1.)
# first bout in green, last bout in black
im = ax.scatter(ma.hstack(X_toplot[0]), ma.hstack(Y_toplot[0]), c='g', s=3., alpha=.75)
im = ax.scatter(ma.hstack(X_toplot[-1]), ma.hstack(Y_toplot[-1]), c='k', s=3., alpha=.75)

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
for axis in ['bottom', 'left']:
    ax.spines[axis].set_linewidth(4)

# 0-950 frame on both axes with blank tick labels
ax.set_xticks([0, 950])
ax.set_xticklabels(['', ''], fontsize=30)
ax.set_yticks([0, 950])
ax.set_yticklabels(['', ''], fontsize=30)
ax.set_xlim(0, 950)
ax.set_ylim(0, 950)
plt.tight_layout()
plt.axis('off')
plt.show()

In [None]:
# Inset: lab-space head trajectory of recording `rec`. Every position sample of a bout is
# coloured by that bout's value of distorted eigenfunction column 1 (diverging map centred at 0).
eigfish = distorted_fish[:, :, 1]
rec = 442

fig, ax = plt.subplots(1, 1, figsize=(10, 10))

# plot bouts [st, en); en is the number of unmasked bouts in this recording
st = 0
en = len(eigfish[rec].compressed())
print(en)

X_toplot = X_head[rec, st:en, :, 0]
Y_toplot = X_head[rec, st:en, :, 1]
print(X_toplot.shape)

# one colour value per bout, repeated for each of the X_head.shape[2] samples in that bout
colors_ = ma.repeat(eigfish[rec, st:en][:, np.newaxis], X_head.shape[2], axis=1)
divnorm = pltcolors.TwoSlopeNorm(vmin=-1, vcenter=0, vmax=1.)

im = ax.scatter(ma.hstack(X_toplot), ma.hstack(Y_toplot), c=colors_, norm=divnorm, cmap='PuOr_r',
                ec='k', linewidths=0.05, s=20., alpha=1.)
# first bout in green, last bout in black
im = ax.scatter(ma.hstack(X_toplot[0]), ma.hstack(Y_toplot[0]), c='g', s=3., alpha=.75)
im = ax.scatter(ma.hstack(X_toplot[-1]), ma.hstack(Y_toplot[-1]), c='k', s=3., alpha=.75)

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
for axis in ['bottom', 'left']:
    ax.spines[axis].set_linewidth(4)

# 0-950 frame on both axes with blank tick labels
ax.set_xticks([0, 950])
ax.set_xticklabels(['', ''], fontsize=30)
ax.set_yticks([0, 950])
ax.set_yticklabels(['', ''], fontsize=30)
ax.set_xlim(0, 950)
ax.set_ylim(0, 950)
plt.tight_layout()
plt.axis('off')
plt.show()

In [None]:
# Inset: lab-space head trajectory of recording `rec`. Every position sample of a bout is
# coloured by that bout's value of distorted eigenfunction column 2 (diverging map centred at 0).
eigfish = distorted_fish[:, :, 2]
rec = 442

fig, ax = plt.subplots(1, 1, figsize=(10, 10))

# plot bouts [st, en); en is the number of unmasked bouts in this recording
st = 0
en = len(eigfish[rec].compressed())
print(en)

X_toplot = X_head[rec, st:en, :, 0]
Y_toplot = X_head[rec, st:en, :, 1]
print(X_toplot.shape)

# one colour value per bout, repeated for each of the X_head.shape[2] samples in that bout
colors_ = ma.repeat(eigfish[rec, st:en][:, np.newaxis], X_head.shape[2], axis=1)
divnorm = pltcolors.TwoSlopeNorm(vmin=-1, vcenter=0, vmax=1.)

im = ax.scatter(ma.hstack(X_toplot), ma.hstack(Y_toplot), c=colors_, norm=divnorm, cmap='PuOr_r',
                ec='k', linewidths=0.05, s=20., alpha=1.)
# first bout in green, last bout in black
im = ax.scatter(ma.hstack(X_toplot[0]), ma.hstack(Y_toplot[0]), c='g', s=3., alpha=.75)
im = ax.scatter(ma.hstack(X_toplot[-1]), ma.hstack(Y_toplot[-1]), c='k', s=3., alpha=.75)

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
for axis in ['bottom', 'left']:
    ax.spines[axis].set_linewidth(4)

# 0-950 frame on both axes with blank tick labels
ax.set_xticks([0, 950])
ax.set_xticklabels(['', ''], fontsize=30)
ax.set_yticks([0, 950])
ax.set_yticklabels(['', ''], fontsize=30)
ax.set_xlim(0, 950)
ax.set_ylim(0, 950)
plt.tight_layout()
plt.axis('off')
plt.show()