In [1]:
import sys
sys.path.append("../")
import recurrency_estimation as re
from tqdm.notebook import tqdm
import scipy.stats
from dask_jobqueue import SGECluster
from dask.distributed import Client, LocalCluster
import dask
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

import xarray as xr
import itertools
import time
import matplotlib


from Session import SessionLite # required to unpickle data
from linear_model import PoolAcrossSessions, LinearModel, MultiSessionModel

import pop_off_functions as pof
import pop_off_plotting as pop

from sklearn.decomposition import FactorAnalysis

from matplotlib import scale as mscale
from matplotlib import transforms as mtransforms
from matplotlib.ticker import FixedFormatter, FixedLocator
from numpy import ma

from scipy.optimize import curve_fit

import pingouin as pg

/home/loidolt/RowlandEtAl/popping-off/popoff/popoff/loadpaths.py
/home/loidolt/RowlandEtAl/Vape


In [2]:
# Re-import edited modules automatically before running code, so changes to the project's
# helper modules (recurrency_estimation, linear_model, pop_off_functions, pop_off_plotting)
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Publication figure style. Resetting to matplotlib's defaults first makes this cell
# idempotent: re-running it always yields the same configuration.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
matplotlib.rcParams.update({
    # text
    "font.family": "sans-serif",
    # figure size / resolution
    "figure.figsize": [3.4, 2.7],  # APS single column
    "figure.dpi": 300,  # this primarily affects the size on screen
    # black axis labels, frame and ticks
    "axes.labelcolor": "black",
    "axes.edgecolor": "black",
    "xtick.color": "black",
    "ytick.color": "black",
    # uniform 10 pt font sizes
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "axes.labelsize": 10,
    "axes.titlesize": 10,
    "legend.fontsize": 10,
    "legend.title_fontsize": 10,
    # open frame: no right / top spines
    "axes.spines.right": False,
    "axes.spines.top": False,
})

In [4]:
# Colour for every trial-type label used in the figures. Many labels are aliases
# (case variants / alternative spellings) that share a colour.
color_tt = {}
# behavioural outcomes, lower-case and capitalised spellings
color_tt.update({'hit': '#117733', 'miss': '#882255', 'fp': '#88CCEE', 'cr': '#DDCC77'})
color_tt.update({'Hit': '#117733', 'Miss': '#882255', 'FP': '#88CCEE', 'CR': '#DDCC77'})
# 'urh' / 'arm' labels (assigned by pof.label_urh_arm) and the reward-only condition with its aliases
color_tt.update({'urh': '#44AA99', 'arm': '#AA4499', 'spont': '#332288', 'prereward': '#332288',
                 'reward\nonly': '#332288', 'Reward\nonly': '#332288',
                 'pre_reward': '#332288', 'Reward': '#332288', 'reward only': '#332288',
                 'rew. only': '#332288'})
# pooled conditions, photostimulation, and "too soon" trials
color_tt.update({'hit&miss': 'k', 'fp&cr': 'k',
                 'photostim': sns.color_palette()[6], 'too_': 'grey'})
# three shades (light -> dark) for sub-grouped hit / miss trials (_n* and _c* suffixes)
color_tt.update({'hit_n1': '#b0eac9', 'hit_n2': '#5ab17f', 'hit_n3': '#117733',
                 'miss_n1': '#a69098', 'miss_n2': '#985d76', 'miss_n3': '#882255'})
color_tt.update({'hit_c1': '#b0eac9', 'hit_c2': '#5ab17f', 'hit_c3': '#117733',
                 'miss_c1': '#a69098', 'miss_c2': '#985d76', 'miss_c3': '#882255'})

Set the environment variable OUTDATED_RAISE_EXCEPTION=1 for a full traceback.
  **kwargs


In [5]:
#Figure 5a

## Sketch manually created in Affinity Designer.

In [6]:
#Figure 5b
## load data
hit_per_var_explained = np.load('../final_data/LFA.T_per_var_explained_hit_1000fits_5factors.npy')
miss_per_var_explained = np.load('../final_data/LFA.T_per_var_explained_miss_1000fits_5factors.npy')

## average across resamples (axis 1); nanmean skips resamples that produced NaN
hit_factors_mu = np.nanmean(hit_per_var_explained, axis=1)
miss_factors_mu = np.nanmean(miss_per_var_explained, axis=1)

# After averaging, rows are sessions and columns are latent factors (this is how the
# plotting below indexes them). Take both sizes from the data instead of hard-coding
# them, so re-saved fits with a different number of sessions cannot be silently truncated.
n_sessions, num_factors = hit_factors_mu.shape
assert miss_factors_mu.shape == hit_factors_mu.shape, \
    "hit and miss fits must have the same (session, factor) shape"
factor_ids = np.arange(1, num_factors + 1)

# plot variance explained: thick lines = mean across sessions, thin lines = single sessions
fig, ax = plt.subplots(figsize=(2.25, 1.5), dpi=300)

ax.plot(factor_ids, np.mean(miss_factors_mu, axis=0), color=color_tt['miss'], linewidth=2, label='miss')
ax.plot(factor_ids, np.mean(hit_factors_mu, axis=0), color=color_tt['hit'], linewidth=2, label='hit')

for i_session in range(n_sessions):
    ax.plot(factor_ids, hit_factors_mu[i_session], color=color_tt['hit'], alpha=0.5, linewidth=0.5)
    ax.plot(factor_ids, miss_factors_mu[i_session], color=color_tt['miss'], alpha=0.5, linewidth=0.5)

## Hide the right and top spines; only show ticks on the left and bottom spines
ax.spines.right.set_visible(False)
ax.spines.top.set_visible(False)
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')

ax.set_xticks(factor_ids)
ax.set_yticks([0, 25, 50])

ax.set_xlim(1, num_factors)
ax.set_ylim(0, 50)

ax.set_xlabel("latent factor")
ax.set_ylabel("% var. expl.")

ax.legend()
fig.tight_layout()

fig.savefig('../final_reports/Figure5b.pdf')



In [7]:
#Figure 5c
## paired comparison (one pair per session) of % variance explained by the first factor
stats = pg.wilcoxon(miss_factors_mu[:, 0], hit_factors_mu[:, 0])
p_val = stats['p-val']['Wilcoxon']

n_pairs = len(miss_factors_mu)
tmp_df = pd.DataFrame(
    {
        "pervarexp": np.concatenate([miss_factors_mu[:, 0], hit_factors_mu[:, 0]]),
        "trial_type": ["miss"] * n_pairs + ["hit"] * n_pairs,
    }
)
# Horizontal jitter for the dots. Seeded so the saved figure is identical on every run
# (the unseeded global RNG produced a different PDF each time).
tmp_df['x'] = np.random.default_rng(0).standard_normal(2 * n_pairs) * 0.1

miss_rows = tmp_df[tmp_df['trial_type'] == 'miss']
hit_rows = tmp_df[tmp_df['trial_type'] == 'hit']
x_miss, y_miss = miss_rows['x'].to_numpy(), miss_rows['pervarexp'].to_numpy()
x_hit, y_hit = 1 + hit_rows['x'].to_numpy(), hit_rows['pervarexp'].to_numpy()

fig, ax = plt.subplots(figsize=(2.25, 1.75), dpi=300)

ax.set_ylim(0, 50)
ax.set_yticks([0, 25, 50])
ax.set_ylabel("% var. expl. \n by first factor")

ax.plot(x_miss, y_miss, '.', color='k', markersize=10)
ax.plot(x_hit, y_hit, '.', color='k', markersize=10)

# One line per session from its miss value to its hit value (each column of the 2 x n arrays)
ax.plot(np.vstack([x_miss, x_hit]), np.vstack([y_miss, y_hit]), c='k', alpha=0.7)

ax.set_title("p = {:.3f}".format(p_val), fontsize=10)

ax.set_xlabel('')
# Set tick positions before labels so the FixedFormatter matches a FixedLocator
ax.set_xticks([0, 1])
ax.set_xticklabels(['miss', 'hit'])
ax.tick_params(bottom=False)
sns.despine()

fig.tight_layout()

fig.savefig('../final_reports/Figure5c.pdf')

In [8]:
#Figure 5d
##load data
remove_targets = False
pas = PoolAcrossSessions(save_PCA=False, subsample_sessions=False,
                         remove_targets=remove_targets, remove_toosoon=True)

# Re-key the sessions 0..n-1 in load order (pas may skip an int as key) and tag
# each with a "<mouse>_R<run>" signature.
int_keys_pas_sessions = pas.sessions.keys()
sessions = {}
for i_s, ses in enumerate(pas.sessions.values()):
    ses.signature = f'{ses.mouse}_R{ses.run_number}'
    sessions[i_s] = ses
print(sessions)
assert len(sessions) == 11, f"expected 11 sessions, found {len(sessions)}"
pof.label_urh_arm(sessions=sessions)  # label arm and urh

##select example session
n_session = 8
example_lm = pas.linear_models[n_session]

##rebin fluorescence: sum every `rebin` consecutive frames into one bin
# NOTE(review): flu is assumed to be (neuron, trial, frame); axis 0 is indexed with the
# s1_bool mask and axis 1 with trial outcomes below — confirm.
rebin = 3
F = example_lm.flu
n_bins = F.shape[2] // rebin
# Drop trailing frames that do not fill a whole bin; reshaping the full frame axis
# raises a ValueError whenever its length is not a multiple of `rebin`.
F_rebinned = F[:, :, :n_bins * rebin].reshape(F.shape[0], F.shape[1], n_bins, rebin).sum(axis=-1)

##find trials
s1_idx = np.nonzero(example_lm.session.s1_bool)[0]
s2_idx = np.nonzero(example_lm.session.s2_bool)[0]
hit_idx = np.nonzero(example_lm.session.outcome == 'hit')[0]
miss_idx = np.nonzero(example_lm.session.outcome == 'miss')[0]

HitAndMiss_idx = np.sort(np.concatenate((hit_idx, miss_idx)))

##re-slice fluorescence: rows selected by s1_bool, rebinned 'pre' window (bins 10-74)
pre_window = slice(10, 75)
pre_F_rebinned = F_rebinned[s1_idx][:, :, pre_window]
pre_HitMiss_F_rebinned = F_rebinned[s1_idx][:, HitAndMiss_idx, pre_window]

# Concatenate trials along time: one row per neuron, trial*bin columns
pre_F_concatenated = pre_F_rebinned.reshape(pre_F_rebinned.shape[0], -1)
pre_HitMiss_F_concatenated = pre_HitMiss_F_rebinned.reshape(pre_HitMiss_F_rebinned.shape[0], -1)

##run LFA: 5 latent factors, samples (trial*bin) as observations and neurons as features;
# split activity into the factor-explained part and the remainder.
# NOTE(review): sklearn's FactorAnalysis centres the data, so pre_F_shared omits the
# per-neuron mean and pre_F_nonshared retains it (irrelevant for covariances).
transformer = FactorAnalysis(n_components=5)
pre_F_factors = transformer.fit_transform(pre_HitMiss_F_concatenated.T)
pre_F_shared = np.dot(pre_F_factors, transformer.components_).T
pre_F_nonshared = pre_HitMiss_F_concatenated - pre_F_shared

def ML_CorrMat_Plot(F, n_neurons=50, vmin=-0.3, vmax=0.3, TvdP_norm=False):
    """Show the correlation matrix of the first ``n_neurons`` rows of ``F`` as a heat map.

    Parameters
    ----------
    F : 2-D array
        One row per neuron (rows are the variables passed to ``np.corrcoef``).
    n_neurons : int
        Number of leading neurons to show; clipped to the number of rows of ``F``.
    vmin, vmax : float
        Colour-scale limits.
    TvdP_norm : bool
        If True, divide the matrix by the mean of its diagonal. For a correlation matrix
        that diagonal is 1 (for rows with non-zero variance), so this is normally a no-op.

    Returns
    -------
    fig, ax : the created figure and axes.
    """
    C = np.corrcoef(F)
    if TvdP_norm:
        C = C / np.mean(np.diag(C))

    n_shown = min(n_neurons, C.shape[0])
    fig, ax = plt.subplots(figsize=(2, 2), dpi=300)
    # extent makes the image span exactly [0, n_shown] on both axes, so the ticks at
    # 0 / half / n_shown sit on the matrix edges instead of stretching the axes past it.
    mats = ax.imshow(C[:n_shown, :n_shown], cmap='RdYlBu_r', vmin=vmin, vmax=vmax,
                     extent=(0, n_shown, n_shown, 0))
    fig.colorbar(mats, ax=ax, shrink=0.5, pad=0.15, aspect=8)

    ticks = [0, n_shown / 2, n_shown]
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    return fig, ax

fig, ax = ML_CorrMat_Plot(pre_HitMiss_F_concatenated, vmin=-1, vmax=1)

fig.tight_layout()
fig.savefig("../final_reports/Figure5d.pdf")

long post time
long post time
long post time
long post time
long post time
Mouse RL070, run 29  registered no-lick hit. changed to too soon
long post time
long post time
Mouse RL117, run 29  registered no-lick hit. changed to too soon
Mouse RL117, run 29  registered no-lick hit. changed to too soon
long post time
long post time
long post time
Mouse RL116, run 32  registered no-lick hit. changed to too soon
Mouse RL116, run 32  registered no-lick hit. changed to too soon
Mouse RL116, run 32  registered no-lick hit. changed to too soon
long post time
ALERT SESSIONS NOT SUBSAMPLED
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
{0: instance Mouse J064, run 10 of Session class, 1: instance Mouse J064, run 11 of Session class, 2: instance Mouse J064, run 14 of Session class, 3: instance Mouse RL070, run 28 of Session class, 4: instance Mouse RL070, run 29 of Session class, 5:

In [10]:
#Figure 5e
##load data
# Fits were saved in chunks over the number of latent factors; stitch them back
# together along the n_fact dimension.
res_dataset1, res_dataset2, res_dataset3, res_dataset4, res_dataset5 = (
    xr.load_dataarray(path) for path in (
        "../final_data/ML_nfac0-2_results_1000fits.nc",
        "../final_data/ML_nfac3-9_results_1000fits.nc",
        "../final_data/ML_nfac10-15_results_1000fits.nc",
        "../final_data/ML_nfac16-20_results_1000fits.nc",
        "../final_data/ML_nfac21-25_results_1000fits.nc",
    )
)
res_dataset = xr.concat(
    [res_dataset1, res_dataset2, res_dataset3, res_dataset4, res_dataset5], dim='n_fact'
)

##plot the 5-factor results
fig = plt.figure(figsize=(7, 10))
re.plotting.figure_single_nfact_connecting_lines(res_dataset, fig, n_fact=5)

fig.tight_layout()
fig.savefig("../final_reports/Figure5e.pdf", dpi=300, bbox_inches="tight")

In [11]:
#Figure 5f
##load data
# Same preparation as the Figure 5d cell, repeated so this cell can run on its own.
remove_targets = False
pas = PoolAcrossSessions(save_PCA=False, subsample_sessions=False,
                         remove_targets=remove_targets, remove_toosoon=True)

## Create sessions object from PAS: re-key 0..n-1 in load order (pas may skip an int
## as key) and tag each session with a "<mouse>_R<run>" signature. Any previous
## `sessions` dict is simply replaced.
int_keys_pas_sessions = pas.sessions.keys()
sessions = {}
for i_s, ses in enumerate(pas.sessions.values()):
    ses.signature = f'{ses.mouse}_R{ses.run_number}'
    sessions[i_s] = ses
print(sessions)
assert len(sessions) == 11, f"expected 11 sessions, found {len(sessions)}"
pof.label_urh_arm(sessions=sessions)  # label arm and urh

##select example session
n_session = 8
example_lm = pas.linear_models[n_session]

##rebin fluorescence: sum every `rebin` consecutive frames into one bin
# NOTE(review): flu is assumed to be (neuron, trial, frame) — confirm.
rebin = 3
F = example_lm.flu
n_bins = F.shape[2] // rebin
# Drop trailing frames that do not fill a whole bin (a plain reshape would raise).
F_rebinned = F[:, :, :n_bins * rebin].reshape(F.shape[0], F.shape[1], n_bins, rebin).sum(axis=-1)

##find trials
s1_idx = np.nonzero(example_lm.session.s1_bool)[0]
s2_idx = np.nonzero(example_lm.session.s2_bool)[0]
hit_idx = np.nonzero(example_lm.session.outcome == 'hit')[0]
miss_idx = np.nonzero(example_lm.session.outcome == 'miss')[0]

HitAndMiss_idx = np.sort(np.concatenate((hit_idx, miss_idx)))

##re-slice fluorescence: rows selected by s1_bool, rebinned 'pre' window (bins 10-74)
pre_window = slice(10, 75)
pre_F_rebinned = F_rebinned[s1_idx][:, :, pre_window]
pre_HitMiss_F_rebinned = F_rebinned[s1_idx][:, HitAndMiss_idx, pre_window]

# Concatenate trials along time: one row per neuron, trial*bin columns
pre_F_concatenated = pre_F_rebinned.reshape(pre_F_rebinned.shape[0], -1)
pre_HitMiss_F_concatenated = pre_HitMiss_F_rebinned.reshape(pre_HitMiss_F_rebinned.shape[0], -1)

##run LFA (5 factors) and keep the activity not explained by the factors
# NOTE(review): FactorAnalysis centres the data, so pre_F_nonshared still contains the
# per-neuron mean; np.cov below removes it anyway.
transformer = FactorAnalysis(n_components=5)
pre_F_factors = transformer.fit_transform(pre_HitMiss_F_concatenated.T)
pre_F_shared = np.dot(pre_F_factors, transformer.components_).T
pre_F_nonshared = pre_HitMiss_F_concatenated - pre_F_shared

##plot
def ML_CovMat_Plot(F, n_neurons=50, vmin=-0.3, vmax=0.3, TvdP_norm=True):
    """Show the covariance matrix of the first ``n_neurons`` rows of ``F`` as a heat map.

    Parameters
    ----------
    F : 2-D array
        One row per neuron (rows are the variables passed to ``np.cov``).
    n_neurons : int
        Number of leading neurons to show; clipped to the number of rows of ``F``.
    vmin, vmax : float
        Colour-scale limits.
    TvdP_norm : bool
        If True (default), divide the matrix by its mean diagonal entry (the mean
        variance), so entries are relative to the average neuron's variance.

    Returns
    -------
    fig, ax : the created figure and axes.
    """
    C = np.cov(F)
    if TvdP_norm:
        C = C / np.mean(np.diag(C))

    n_shown = min(n_neurons, C.shape[0])
    fig, ax = plt.subplots(figsize=(2, 2), dpi=300)
    # extent makes the image span exactly [0, n_shown], so the ticks at 0 / half / n_shown
    # sit on the matrix edges instead of stretching the axes past the image.
    mats = ax.imshow(C[:n_shown, :n_shown], cmap='RdYlBu_r', vmin=vmin, vmax=vmax,
                     extent=(0, n_shown, n_shown, 0))
    fig.colorbar(mats, ax=ax, shrink=0.5, pad=0.15, aspect=8)

    ticks = [0, n_shown / 2, n_shown]
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    return fig, ax

fig, ax = ML_CovMat_Plot(pre_F_nonshared, vmin=-0.3, vmax=0.3, TvdP_norm=True)

fig.tight_layout()
fig.savefig("../final_reports/Figure5f.pdf")

long post time
long post time
long post time
long post time
long post time
Mouse RL070, run 29  registered no-lick hit. changed to too soon
long post time
long post time
Mouse RL117, run 29  registered no-lick hit. changed to too soon
Mouse RL117, run 29  registered no-lick hit. changed to too soon
long post time
long post time
long post time
Mouse RL116, run 32  registered no-lick hit. changed to too soon
Mouse RL116, run 32  registered no-lick hit. changed to too soon
Mouse RL116, run 32  registered no-lick hit. changed to too soon
long post time
ALERT SESSIONS NOT SUBSAMPLED
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
{0: instance Mouse J064, run 10 of Session class, 1: instance Mouse J064, run 11 of Session class, 2: instance Mouse J064, run 14 of Session class, 3: instance Mouse RL070, run 28 of Session class, 4: instance Mouse RL070, run 29 of Session class, 5:

In [12]:
#Figure 5g
##load data
# Same preparation as the Figure 5d cell, repeated so this cell can run on its own.
remove_targets = False
pas = PoolAcrossSessions(save_PCA=False, subsample_sessions=False,
                         remove_targets=remove_targets, remove_toosoon=True)

## Create sessions object from PAS: re-key 0..n-1 in load order (pas may skip an int
## as key) and tag each session with a "<mouse>_R<run>" signature. Any previous
## `sessions` dict is simply replaced.
int_keys_pas_sessions = pas.sessions.keys()
sessions = {}
for i_s, ses in enumerate(pas.sessions.values()):
    ses.signature = f'{ses.mouse}_R{ses.run_number}'
    sessions[i_s] = ses
print(sessions)
assert len(sessions) == 11, f"expected 11 sessions, found {len(sessions)}"
pof.label_urh_arm(sessions=sessions)  # label arm and urh

##select example session
n_session = 8
example_lm = pas.linear_models[n_session]

##rebin fluorescence: sum every `rebin` consecutive frames into one bin
# NOTE(review): flu is assumed to be (neuron, trial, frame) — confirm.
rebin = 3
F = example_lm.flu
n_bins = F.shape[2] // rebin
# Drop trailing frames that do not fill a whole bin (a plain reshape would raise).
F_rebinned = F[:, :, :n_bins * rebin].reshape(F.shape[0], F.shape[1], n_bins, rebin).sum(axis=-1)

##find trials
s1_idx = np.nonzero(example_lm.session.s1_bool)[0]
s2_idx = np.nonzero(example_lm.session.s2_bool)[0]
hit_idx = np.nonzero(example_lm.session.outcome == 'hit')[0]
miss_idx = np.nonzero(example_lm.session.outcome == 'miss')[0]

HitAndMiss_idx = np.sort(np.concatenate((hit_idx, miss_idx)))

##re-slice fluorescence: rows selected by s1_bool, rebinned 'pre' window (bins 10-74)
pre_window = slice(10, 75)
pre_F_rebinned = F_rebinned[s1_idx][:, :, pre_window]
pre_HitMiss_F_rebinned = F_rebinned[s1_idx][:, HitAndMiss_idx, pre_window]

# Concatenate trials along time: one row per neuron, trial*bin columns
pre_F_concatenated = pre_F_rebinned.reshape(pre_F_rebinned.shape[0], -1)
pre_HitMiss_F_concatenated = pre_HitMiss_F_rebinned.reshape(pre_HitMiss_F_rebinned.shape[0], -1)

##run LFA (5 factors) and keep the activity not explained by the factors
# NOTE(review): FactorAnalysis centres the data, so pre_F_nonshared still contains the
# per-neuron mean; np.cov below removes it anyway.
transformer = FactorAnalysis(n_components=5)
pre_F_factors = transformer.fit_transform(pre_HitMiss_F_concatenated.T)
pre_F_shared = np.dot(pre_F_factors, transformer.components_).T
pre_F_nonshared = pre_HitMiss_F_concatenated - pre_F_shared

##plot
def ML_CovDist_OffDiagPlot(F, vmin=-0.3, vmax=0.3, TvdP_norm=False):
    """Plot the density of the off-diagonal entries of ``np.cov(F)`` and mark their spread.

    A grey double arrow spanning [-std, +std] (centred on 0) and labelled sigma_CC
    is drawn above the density peak.

    Parameters
    ----------
    F : 2-D array
        One row per neuron (rows are the variables passed to ``np.cov``).
    vmin, vmax : float
        x-axis limits (and outer x ticks).
    TvdP_norm : bool
        If True, divide the covariance matrix by its mean diagonal entry (mean variance).

    Returns
    -------
    fig, ax : the created figure and axes.
    """
    C = np.cov(F)
    if TvdP_norm:
        C = C / np.mean(np.diag(C))

    off_diag = C[~np.eye(C.shape[0], dtype=bool)]
    # 999 equal-width bins between the extreme values; the density is plotted at bin
    # centres (plotting it at the left edges shifted the curve by half a bin).
    bin_edges = np.linspace(np.min(off_diag), np.max(off_diag), 1000)
    off_Hist, _ = np.histogram(off_diag, bin_edges, density=True)
    bin_centres = 0.5 * (bin_edges[:-1] + bin_edges[1:])

    dc = np.std(off_diag)
    peak = np.max(off_Hist)

    fig, ax = plt.subplots(figsize=(1.8, 1.6), dpi=300)
    ax.plot(bin_centres, off_Hist, linewidth=2, color='gold')

    # Two opposing arrows over [-dc, dc] form a double-headed sigma_CC marker.
    # Head sizes are in data units and tuned for this figure.
    arrow_kw = dict(color='grey', head_length=0.0175, head_width=1.5, length_includes_head=True)
    ax.arrow(-dc, peak * 1.2, 2 * dc, 0, **arrow_kw)
    ax.arrow(dc, peak * 1.2, -2 * dc, 0, **arrow_kw)
    ax.text(0, peak * 1.3, r"$\sigma_{CC}$", horizontalalignment='center', fontsize=8, color='grey')

    ax.set_xlim([vmin, vmax])
    ax.set_ylim([0, peak * 1.5])

    ax.set_xticks([vmin, 0, vmax])
    # y axis in arbitrary units: the top of the axis is labelled 1
    ax.set_yticks([0, peak * 1.5])
    ax.set_yticklabels([0, 1])

    ax.set_ylabel('density [a.u.]')
    ax.set_xlabel('off-diag covariance')
    return fig, ax

fig, ax = ML_CovDist_OffDiagPlot(pre_F_nonshared, TvdP_norm=True)

fig.tight_layout()
fig.savefig("../final_reports/Figure5g.pdf")

long post time
long post time
long post time
long post time
long post time
Mouse RL070, run 29  registered no-lick hit. changed to too soon
long post time
long post time
Mouse RL117, run 29  registered no-lick hit. changed to too soon
Mouse RL117, run 29  registered no-lick hit. changed to too soon
long post time
long post time
long post time
Mouse RL116, run 32  registered no-lick hit. changed to too soon
Mouse RL116, run 32  registered no-lick hit. changed to too soon
Mouse RL116, run 32  registered no-lick hit. changed to too soon
long post time
ALERT SESSIONS NOT SUBSAMPLED
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
{0: instance Mouse J064, run 10 of Session class, 1: instance Mouse J064, run 11 of Session class, 2: instance Mouse J064, run 14 of Session class, 3: instance Mouse RL070, run 28 of Session class, 4: instance Mouse RL070, run 29 of Session class, 5:

In [13]:
#Figure 5h
##load data
# taupost fits for test-hit, test-miss and pooled hit+miss trials, stacked along trial_type
res_dataset1, res_dataset2, res_dataset3 = (
    xr.load_dataarray(path) for path in (
        "../final_data/ML_LFA_nfac0,5_S1_TestHit_taupost_results_1000fits.nc",
        "../final_data/ML_LFA_nfac0,5_S1_TestMiss_taupost_results_1000fits.nc",
        "../final_data/ML_LFA_nfac0,5_S1_TestHitAndMiss_taupost_results_1000fits.nc",
    )
)
res_dataset = xr.concat([res_dataset1, res_dataset2, res_dataset3], dim='trial_type')

# pooled hit+miss results for the S1 and S2 populations
res_S1_dataset = xr.load_dataarray("../final_data/ML_nfac0+5_HitAndMissResults_1000fits.nc")
res_S2_dataset = xr.load_dataarray("../final_data/ML_nfac0+5_HitAndMiss_S2Results_1000fits.nc")

##calc s = var_cc / mean_var on the 5-factor residual activity
# NOTE(review): var_cc is later passed to calc_R as `sigma_CC`; confirm whether the saved
# quantity is a standard deviation or a variance of the cross-covariances.
var_cc, mean_var = (
    res_dataset.sel(n_fact=5, activity_type="residual", trial_type="HitAndMiss", variable=name).values
    for name in ('var_cc', 'mean_var')
)
s = var_cc / mean_var

var_cc_S1, mean_var_S1 = (
    res_S1_dataset.sel(n_fact=5, activity_type="residual", sample=0, trial_type="HitAndMiss", variable=name).values
    for name in ('var_cc', 'mean_var')
)
s_S1 = var_cc_S1 / mean_var_S1

var_cc_S2, mean_var_S2 = (
    res_S2_dataset.sel(n_fact=5, activity_type="residual", sample=0, trial_type="HitAndMiss", variable=name).values
    for name in ('var_cc', 'mean_var')
)
s_S2 = var_cc_S2 / mean_var_S2

##calc R
def calc_R(mu_V, sigma_CC, N=50000):
    """Recurrence R from a mean variance and a cross-covariance spread.

    With the normalised spread s = sigma_CC / mu_V,
        R = sqrt(1 - sqrt(1 / (1 + N * s**2))).
    ``N`` is presumably the network size (TODO confirm). Works element-wise on
    numpy arrays.
    """
    normalised_spread = sigma_CC / mu_V
    inner = np.sqrt(1 / (1 + N * normalised_spread**2))
    return np.sqrt(1 - inner)

# Recurrence for the taupost results and the pooled S1 / S2 results, all with N = 50000.
R, R_S1, R_S2 = (
    calc_R(mu_V, sigma_CC, N=50000)
    for mu_V, sigma_CC in ((mean_var, var_cc),
                           (mean_var_S1, var_cc_S1),
                           (mean_var_S2, var_cc_S2))
)

##plot
class CloseToOne(mscale.ScaleBase):
    """Axis scale that stretches values approaching 1 via y -> -log10(1 - y).

    Major ticks sit at 1 - 10**-k for k = 0..nines (0, 0.9, 0.99, ...). Minor ticks
    sit at 1 - m*10**-k for m = 2..9. Select it with
    ``ax.set_yscale('close_to_one', nines=k)``; ``nines`` defaults to 10.
    """
    name = 'close_to_one'

    def __init__(self, axis, **kwargs):
        super().__init__(axis)
        self.nines = kwargs.get('nines', 10)

    def get_transform(self):
        return self.Transform(self.nines)

    def set_default_locators_and_formatters(self, axis):
        decades = range(1 + self.nines)
        axis.set_major_locator(FixedLocator(
                np.array([1 - 10**(-k) for k in decades])))
        axis.set_major_formatter(FixedFormatter(
                [str(1 - 10**(-k)) for k in decades]))
        axis.set_minor_locator(FixedLocator(
                np.array([[1 - m * 10**(-k) for m in range(2, 10)] for k in decades]).ravel()))

    def limit_range_for_scale(self, vmin, vmax, minpos):
        # Keep the upper view limit below 1, where the transform diverges.
        return vmin, min(1 - 10**(-self.nines), vmax)

    class Transform(mtransforms.Transform):
        input_dims = 1
        output_dims = 1
        is_separable = True

        def __init__(self, nines):
            super().__init__()
            self.nines = nines

        def transform_non_affine(self, a):
            # Mask values within one decade beyond the last tick of 1. The previous version
            # built this mask but then transformed the unmasked input, so it never took effect.
            masked = ma.masked_where(a > 1 - 10**(-1 - self.nines), a)
            if masked.mask.any():
                return -ma.log10(1 - masked)
            return -np.log10(1 - a)

        def inverted(self):
            return CloseToOne.InvertedTransform(self.nines)

    class InvertedTransform(mtransforms.Transform):
        input_dims = 1
        output_dims = 1
        is_separable = True

        def __init__(self, nines):
            super().__init__()
            self.nines = nines

        def transform_non_affine(self, a):
            return 1. - 10**(-a)

        def inverted(self):
            return CloseToOne.Transform(self.nines)

mscale.register_scale(CloseToOne)

def ML_R_PlotSketch_InverseLog(smin=0.0001, smax=0.2, ax=False, N=50000, color='black', linestyle='-'):
    """Draw the theory curve R(s) = sqrt(1 - sqrt(1 / (1 + N s^2))) for s in [smin, smax].

    The y axis uses the 'close_to_one' scale (3 nines), which must already be registered.

    Parameters
    ----------
    smin, smax : float
        Range of the normalised spread s; the x axis spans [0, smax].
    ax : matplotlib Axes or False/None
        Axes to draw on; a new (2 x 1.8 in) figure is created when False (default) or None.
    N : int
        The N in the R(s) formula.
    color, linestyle
        Line style of the curve.

    Returns
    -------
    ax : the axes drawn on.
    """
    s_range = np.linspace(smin, smax, 1000)
    R_curve = np.sqrt(1 - np.sqrt(1 / (1 + N * s_range**2)))

    if ax is False or ax is None:
        _, ax = plt.subplots(figsize=(2, 1.8), dpi=300)

    ax.plot(s_range, R_curve, color=color, linestyle=linestyle)

    # Follow the requested range (identical to the former fixed 0-0.2 axis for the default smax)
    ax.set_xlim(0, smax)
    ax.set_xticks([0, smax / 2, smax])
    ax.set_yscale('close_to_one', nines=3)

    ax.set_xlabel(r"$\sigma_{CC}$" + " (normalised)")
    ax.set_ylabel("recurrence " + r"$R$")
    return ax

fig, ax = plt.subplots(figsize=(2,1.75), dpi=300)

# Population means of s and R that anchor the guide lines.
mean_s_S1, mean_R_S1 = np.mean(s_S1), np.mean(R_S1)
mean_s_S2, mean_R_S2 = np.mean(s_S2), np.mean(R_S2)

# Theory curve R(s) on the close_to_one y scale.
ML_R_PlotSketch_InverseLog(ax=ax, N=50000)

# One open circle per value; S2 (silver) above S1 (gold), both may extend past the axes.
ax.scatter(s_S2, R_S2, s=25,
           c='none', edgecolor='silver', alpha=1, zorder=8, clip_on=False)
ax.scatter(s_S1, R_S1, s=25,
           c='none', edgecolor='gold', alpha=1, zorder=7, clip_on=False)

# Guide lines from the axes to (mean s, mean R) [solid] and to (mean s -/+ std s,
# mean R -/+ std R) [dotted]. Draw order: S1 vertical, S2 vertical, S2 horizontal, S1 horizontal.
for guide_dir, vals_s, vals_R, centre_s, centre_R, guide_col in (
        ('v', s_S1, R_S1, mean_s_S1, mean_R_S1, 'gold'),
        ('v', s_S2, R_S2, mean_s_S2, mean_R_S2, 'silver'),
        ('h', s_S2, R_S2, mean_s_S2, mean_R_S2, 'silver'),
        ('h', s_S1, R_S1, mean_s_S1, mean_R_S1, 'gold')):
    spread_s, spread_R = np.std(vals_s), np.std(vals_R)
    for end_s, end_R, dash in ((centre_s - spread_s, centre_R - spread_R, {'linestyle': ':'}),
                               (centre_s, centre_R, {}),
                               (centre_s + spread_s, centre_R + spread_R, {'linestyle': ':'})):
        if guide_dir == 'v':
            ax.vlines(end_s, ymin=0, ymax=end_R, color=guide_col, linewidth=1.5, **dash)
        else:
            ax.hlines(end_R, xmin=0, xmax=end_s, color=guide_col, linewidth=1.5, **dash)
del guide_dir, vals_s, vals_R, centre_s, centre_R, guide_col, spread_s, spread_R, end_s, end_R, dash

ax.set_yticks([0,0.9,0.99,0.999])
ax.set_ylim(0,0.999)

fig.tight_layout()
fig.savefig("../final_reports/Figure5h.pdf", dpi=300, bbox_inches="tight")

In [14]:
#Figure 5i
# NOTE(review): this cell is identical to the Figure 5e cell apart from the output file
# name — confirm that panel 5i is meant to reproduce the same plot.
##load data: n_fact chunks concatenated along n_fact
res_dataset1, res_dataset2, res_dataset3, res_dataset4, res_dataset5 = (
    xr.load_dataarray(path) for path in (
        "../final_data/ML_nfac0-2_results_1000fits.nc",
        "../final_data/ML_nfac3-9_results_1000fits.nc",
        "../final_data/ML_nfac10-15_results_1000fits.nc",
        "../final_data/ML_nfac16-20_results_1000fits.nc",
        "../final_data/ML_nfac21-25_results_1000fits.nc",
    )
)
res_dataset = xr.concat(
    [res_dataset1, res_dataset2, res_dataset3, res_dataset4, res_dataset5], dim='n_fact'
)

##plot the 5-factor results
fig = plt.figure(figsize=(7, 10))
re.plotting.figure_single_nfact_connecting_lines(res_dataset, fig, n_fact=5)
fig.tight_layout()
fig.savefig("../final_reports/Figure5i.pdf", dpi=300, bbox_inches="tight")

In [15]:
#Figure 5j
##load data: per-session pre/post activity with remove_toosoon and label_urh_arm enabled
session_dict = re.load_data_prepost(label_urh_arm=True, remove_toosoon=True)

##calc
def fit_func(t, A, tau):
    """Single-exponential decay A * exp(-t / tau); the model passed to curve_fit."""
    decay = np.exp(-t / tau)
    return A * decay

n_sessions = len(session_dict)

# All sessions are stacked into (n_sessions, n_frames) arrays below, so they must share a
# frame count; read it from the data instead of from one hard-coded session name.
frame_counts = {ses.shape[2] for ses in session_dict.values()}
assert len(frame_counts) == 1, f"sessions differ in frame count: {sorted(frame_counts)}"
n_frames = next(iter(frame_counts))

# Baseline-subtracted mean target trace per session, and fitted decay constants (frames).
testHit_delta = np.zeros((n_sessions, n_frames))
testMiss_delta = np.zeros((n_sessions, n_frames))

Hit_tau_hats = np.zeros(n_sessions)
Miss_tau_hats = np.zeros(n_sessions)

for i_session, (session_name, session) in enumerate(session_dict.items()):
    # NOTE(review): session is assumed to be (neuron, trial, frame) with a per-(neuron, trial)
    # `is_target` mask, matching how it is indexed below — confirm in load_data_prepost.
    targeted_neurons = np.any(session.is_target, axis=1)

    n_neurons = session.shape[0]

    # frame_type -1 marks pre-stim frames, 1 marks post-stim frames
    pre_frames = session.where(session.frame_type.isin([-1]), drop=True).frame_num.values
    post_frames = session.where(session.frame_type.isin([1]), drop=True).frame_num.values

    testStim_trials = session.where(session.stim_type.isin([1]), drop=True).trial_num.values

    hitTrials = session.where(session.trial_type.isin(["hit"]), drop=True).trial_num.values
    missTrials = session.where(session.trial_type.isin(["miss"]), drop=True).trial_num.values
    HitAndMiss_trials = session.where(session.trial_type.isin(["hit", "miss"]), drop=True).trial_num.values

    testHitTrials = np.intersect1d(testStim_trials, hitTrials)
    testMissTrials = np.intersect1d(testStim_trials, missTrials)
    testHitAndMissTrials = np.union1d(testHitTrials, testMissTrials)

    # Per-neuron trial averages; rows of non-targeted neurons stay NaN.
    testHit_TargetNeuronAverages = np.full((n_neurons, n_frames), np.nan)
    testMiss_TargetNeuronAverages = np.full((n_neurons, n_frames), np.nan)

    for i_neuron in np.arange(n_neurons)[targeted_neurons]:
        # average only over the trials in which this neuron was itself targeted
        hit_as_target = testHitTrials[session.is_target[i_neuron, testHitTrials].values]
        miss_as_target = testMissTrials[session.is_target[i_neuron, testMissTrials].values]
        testHit_TargetNeuronAverages[i_neuron] = np.nanmean(session[i_neuron][hit_as_target], axis=0)
        testMiss_TargetNeuronAverages[i_neuron] = np.nanmean(session[i_neuron][miss_as_target], axis=0)

    hit_trace = np.nanmean(testHit_TargetNeuronAverages, axis=0)
    miss_trace = np.nanmean(testMiss_TargetNeuronAverages, axis=0)
    # Baseline = mean over frames 0..last pre-stim frame, *inclusive*; the former
    # `[0:pre_frames[-1]]` slice stopped one frame short of the last pre-stim frame.
    baseline_stop = pre_frames[-1] + 1
    testHit_delta[i_session] = hit_trace - np.nanmean(hit_trace[:baseline_stop])
    testMiss_delta[i_session] = miss_trace - np.nanmean(miss_trace[:baseline_stop])

    # Fit A*exp(-t/tau) on the post-stim frames, with time counted from the first post-stim
    # frame. tau is unchanged by this shift, but A becomes the amplitude at that frame instead
    # of an extrapolation to frame 0 (≈ exp(t0/tau) times larger), which keeps the fit well conditioned.
    t = post_frames - post_frames[0]

    testHit_popt, _ = curve_fit(fit_func, t,
                                testHit_delta[i_session][post_frames],
                                p0=(1, 100),
                                maxfev=10000)
    Hit_tau_hats[i_session] = testHit_popt[1]

    testMiss_popt, _ = curve_fit(fit_func, t,
                                 testMiss_delta[i_session][post_frames],
                                 p0=(1, 100),
                                 maxfev=10000)
    Miss_tau_hats[i_session] = testMiss_popt[1]

##plot
# Time axis: frame index / 30 frames per second, shifted so photostimulation starts at t = 0
# (8 s into the trace). Frames from 3 before stim onset up to 3 after the 0.25 s stimulus
# window are not shown.
frame_rate = 30
stim_onset_s = 8
pre_stop = stim_onset_s * frame_rate - 3                  # 237: end of the pre-stim segment
post_start = int((stim_onset_s + 0.25) * frame_rate) + 3  # 250: start of the post-stim segment
t_pre = np.arange(0, pre_stop) / frame_rate - stim_onset_s
# Post-stim times start exactly at post_start; the former arange(8.25*30+3, 14*30) began at
# frame 250.5 (a half-frame offset) and assumed 420 frames.
t_post = np.arange(post_start, n_frames) / frame_rate - stim_onset_s

fig, ax = plt.subplots(figsize=(2.5, 1.75), dpi=300)

# faint single-session traces
for i_session in range(n_sessions):
    ax.plot(t_pre, testHit_delta[i_session, :pre_stop], color=color_tt['hit'], alpha=1/n_sessions)
    ax.plot(t_pre, testMiss_delta[i_session, :pre_stop], color=color_tt['miss'], alpha=1/n_sessions)

    ax.plot(t_post, testHit_delta[i_session, post_start:], color=color_tt['hit'], alpha=1/n_sessions)
    ax.plot(t_post, testMiss_delta[i_session, post_start:], color=color_tt['miss'], alpha=1/n_sessions)

# session means; pre- and post-stim segments drawn with the same width (the post-stim
# part used to fall back to the default, thinner line)
hit_mean = np.mean(testHit_delta, axis=0)
miss_mean = np.mean(testMiss_delta, axis=0)
ax.plot(t_pre, hit_mean[:pre_stop], color=color_tt['hit'], label='Hit', linewidth=2)
ax.plot(t_pre, miss_mean[:pre_stop], color=color_tt['miss'], label='Miss', linewidth=2)

ax.plot(t_post, hit_mean[post_start:], color=color_tt['hit'], linewidth=2)
ax.plot(t_post, miss_mean[post_start:], color=color_tt['miss'], linewidth=2)

ax.axvspan(xmin=0, xmax=0.25, color=color_tt['photostim'], alpha=0.5)

ax.set_xlim(-2, 6)
ax.set_ylim(-0.1, 0.3)

ax.set_xticks([-2, 0, 2, 4, 6])

ax.set_ylabel(r'Target $\Delta$F/F')

ax.legend(loc='upper right', fontsize=8)

ax.set_xlabel('Time (s)')

fig.tight_layout()
fig.savefig("../final_reports/Figure5j.pdf")

long post time
long post time
long post time
long post time
long post time
Mouse RL070, run 29  registered no-lick hit. changed to too soon
long post time
long post time
Mouse RL117, run 29  registered no-lick hit. changed to too soon
Mouse RL117, run 29  registered no-lick hit. changed to too soon
long post time
long post time
long post time
Mouse RL116, run 32  registered no-lick hit. changed to too soon
Mouse RL116, run 32  registered no-lick hit. changed to too soon
Mouse RL116, run 32  registered no-lick hit. changed to too soon
long post time
ALERT SESSIONS NOT SUBSAMPLED
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time


  0%|          | 0/11 [00:00<?, ?it/s]

{0: instance Mouse J064, run 10 of Session class, 1: instance Mouse J064, run 11 of Session class, 2: instance Mouse J064, run 14 of Session class, 3: instance Mouse RL070, run 28 of Session class, 4: instance Mouse RL070, run 29 of Session class, 5: instance Mouse RL117, run 26 of Session class, 6: instance Mouse RL117, run 29 of Session class, 7: instance Mouse RL117, run 30 of Session class, 8: instance Mouse RL123, run 22 of Session class, 9: instance Mouse RL116, run 32 of Session class, 10: instance Mouse RL116, run 33 of Session class}
URH and ARM trials have been labelled


100%|██████████| 11/11 [00:06<00:00,  1.67it/s]


In [16]:
#Figure 5k
##load data: per-session pre/post activity with remove_toosoon and label_urh_arm enabled
session_dict = re.load_data_prepost(label_urh_arm=True, remove_toosoon=True)

##calc
def fit_func(t, A, tau):
    """Single-exponential decay A * exp(-t / tau); the model passed to curve_fit."""
    decay = np.exp(-t / tau)
    return A * decay

n_sessions = len(session_dict)

fig, ax = plt.subplots(n_sessions, 1, figsize=(5,25), dpi=300)

n_frames = session_dict['Mouse RL116, run 33'].shape[2]

testHit_delta = np.zeros((n_sessions, n_frames))
testMiss_delta = np.zeros((n_sessions, n_frames))

Hit_tau_hats = np.zeros(n_sessions)    
Miss_tau_hats = np.zeros(n_sessions) 

i_session = -1

for session_name, session in session_dict.items():
    i_session += 1
    targeted_neurons = np.any(session.is_target, axis=1)

    n_neurons = session.shape[0]

    pre_frames = session.where(session.frame_type.isin((-1)), drop=True).frame_num.values
    post_frames = session.where(session.frame_type.isin((1)), drop=True).frame_num.values

    testStim_trials = session.where(session.stim_type.isin((1)), drop=True).trial_num.values

    hitTrials = session.where(session.trial_type.isin(("hit")), drop=True).trial_num.values
    missTrials = session.where(session.trial_type.isin(("miss")), drop=True).trial_num.values
    HitAndMiss_trials = session.where(session.trial_type.isin(("hit", "miss")), drop=True).trial_num.values

    testHitTrials = np.intersect1d(testStim_trials, hitTrials)
    testMissTrials = np.intersect1d(testStim_trials, missTrials)
    testHitAndMissTrials = np.union1d(testHitTrials,testMissTrials)

    testHit_TargetNeuronAverages = np.zeros((n_neurons,n_frames))
    testHit_TargetNeuronAverages[:] = np.nan
    
    testMiss_TargetNeuronAverages = np.zeros((n_neurons,n_frames))
    testMiss_TargetNeuronAverages[:] = np.nan

    for i_neuron in np.arange(n_neurons)[targeted_neurons]:
        testHit_TargetNeuronAverages[i_neuron] = np.nanmean(session[i_neuron][testHitTrials[session.is_target[i_neuron,testHitTrials].values]], axis=0)
        testMiss_TargetNeuronAverages[i_neuron] = np.nanmean(session[i_neuron][testMissTrials[session.is_target[i_neuron,testMissTrials].values]], axis=0)

    testHit_delta[i_session] = np.nanmean(testHit_TargetNeuronAverages, axis=0) - np.nanmean(np.nanmean(testHit_TargetNeuronAverages, axis=0)[0:pre_frames[-1]])
    testMiss_delta[i_session] = np.nanmean(testMiss_TargetNeuronAverages, axis=0) - np.nanmean(np.nanmean(testMiss_TargetNeuronAverages, axis=0)[0:pre_frames[-1]])

    t = post_frames

    testHit_popt, _ = curve_fit(fit_func, t, 
                                testHit_delta[i_session][post_frames], 
                                p0=(1, 100),
                                maxfev = 10000)

    Hit_tau_hats[i_session] = testHit_popt[1]
    
    testMiss_popt, _ = curve_fit(fit_func, t, 
                                testMiss_delta[i_session][post_frames], 
                                p0=(1, 100),
                                maxfev = 10000)

    Miss_tau_hats[i_session] = testMiss_popt[1]

##plot
fig, ax = plt.subplots(figsize=(2.5,1.75), dpi=300)

i_session = 2

ax.plot(np.arange(0,8*30-3)/30-8, testHit_delta[i_session,:8*30-3], color=color_tt['hit'], linewidth=0.5)  
ax.plot(np.arange(0,8*30-3)/30-8, testMiss_delta[i_session,:8*30-3], color=color_tt['miss'], linewidth=0.5)

ax.plot(np.arange(8.25*30+3,14*30)/30-8, testHit_delta[i_session,int(8.25*30)+3:], color=color_tt['hit'], linewidth=0.5)  
ax.plot(np.arange(8.25*30+3,14*30)/30-8, testMiss_delta[i_session,int(8.25*30)+3:], color=color_tt['miss'], linewidth=0.5)

t = post_frames

testHit_popt, _ = curve_fit(fit_func, t, 
                            testHit_delta[i_session][post_frames], 
                            p0=(1, 100),
                            maxfev = 10000)


testMiss_popt, _ = curve_fit(fit_func, t, 
                            testMiss_delta[i_session][post_frames], 
                            p0=(1, 100),
                            maxfev = 10000)

ax.plot(t/30-8, fit_func(t, *testHit_popt),
                   color=color_tt['hit'], linestyle='--', 
                   label=r'$\tau_{post}$'+'= {:.1f} s'.format(testHit_popt[1]/30))   

ax.plot(t/30-8, fit_func(t, *testMiss_popt),
                   color=color_tt['miss'], linestyle='--', 
                   label=r'$\tau_{post}$'+'= {:.1f} s'.format(testMiss_popt[1]/30))    

ax.axvspan(xmin=0, xmax=0.25, color=color_tt['photostim'], alpha=0.5)

ax.set_xlim(-2,6)
ax.set_ylim(-0.1,0.3)

ax.set_xticks([-2,0,2,4,6])

ax.set_ylabel(r'Target $\Delta$F/F')

ax.legend(loc='upper right', fontsize=8)

ax.set_xlabel('Time (s)')

plt.tight_layout()
plt.savefig("../final_reports/Figure5k.pdf")

long post time
long post time
long post time
long post time
long post time
Mouse RL070, run 29  registered no-lick hit. changed to too soon
long post time
long post time
Mouse RL117, run 29  registered no-lick hit. changed to too soon
Mouse RL117, run 29  registered no-lick hit. changed to too soon
long post time
long post time
long post time
Mouse RL116, run 32  registered no-lick hit. changed to too soon
Mouse RL116, run 32  registered no-lick hit. changed to too soon
Mouse RL116, run 32  registered no-lick hit. changed to too soon
long post time
ALERT SESSIONS NOT SUBSAMPLED
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time
long post time


  0%|          | 0/11 [00:00<?, ?it/s]

{0: instance Mouse J064, run 10 of Session class, 1: instance Mouse J064, run 11 of Session class, 2: instance Mouse J064, run 14 of Session class, 3: instance Mouse RL070, run 28 of Session class, 4: instance Mouse RL070, run 29 of Session class, 5: instance Mouse RL117, run 26 of Session class, 6: instance Mouse RL117, run 29 of Session class, 7: instance Mouse RL117, run 30 of Session class, 8: instance Mouse RL123, run 22 of Session class, 9: instance Mouse RL116, run 32 of Session class, 10: instance Mouse RL116, run 33 of Session class}
URH and ARM trials have been labelled


100%|██████████| 11/11 [00:08<00:00,  1.26it/s]


In [17]:
#Figure 5l
##load data
# one results file per trial-type group, concatenated along 'trial_type'
res_dataset1 = xr.load_dataarray("../final_data/ML_LFA_nfac0,5_S1_TestHit_taupost_results_1000fits.nc")
res_dataset2 = xr.load_dataarray("../final_data/ML_LFA_nfac0,5_S1_TestMiss_taupost_results_1000fits.nc")
res_dataset3 = xr.load_dataarray("../final_data/ML_LFA_nfac0,5_S1_TestHitAndMiss_taupost_results_1000fits.nc")
res_dataset = xr.concat([res_dataset1, res_dataset2, res_dataset3], dim='trial_type')

# shared selection: residual activity, n_fact=5, hit and miss trials pooled
hm_selection = dict(activity_type='residual', n_fact=5, trial_type='HitAndMiss')
taupost_HitAndMiss = res_dataset.sel(variable='taupost', **hm_selection)
var_cc_HitAndMiss = res_dataset.sel(variable='var_cc', **hm_selection)
mean_var_HitAndMiss = res_dataset.sel(variable='mean_var', **hm_selection)


##calc
def calc_R(mu_V, sigma_CC, N=50000):
    """Recurrence estimate from the covariance spread relative to the mean variance.

    Evaluates ``sqrt(1 - sqrt(1 / (1 + N * (sigma_CC / mu_V)**2)))`` element-wise,
    so scalars, numpy arrays and xarray objects are all accepted.

    Parameters
    ----------
    mu_V : mean variance.
    sigma_CC : spread of the cross-covariances. NOTE(review): the caller passes a
        quantity named ``var_cc``; confirm whether it is a std or a variance,
        since the formula squares it.
    N : multiplier in the formula (default 50000), presumably the assumed
        population size.

    Returns
    -------
    R, with the broadcast shape of the inputs.
    """
    ratio_sq = (sigma_CC / mu_V) ** 2
    inner = 1 / (1 + N * ratio_sq)
    return np.sqrt(1 - np.sqrt(inner))

R_HitAndMiss = calc_R(mean_var_HitAndMiss, var_cc_HitAndMiss, N=50000)

###mask failed fits
# Keep R only inside [0, 1] and tau_post only inside [15, 300] frames. Values
# outside are masked; the two quantities are masked independently.
R_masked = np.ma.masked_outside(R_HitAndMiss.values, 0, 1.0)
tau_masked = np.ma.masked_outside(taupost_HitAndMiss.values, 15, 300)

##fit experiment-theory relation
def fit_func(R, tau_c):
    """Theory relation between recurrence and timescale: ``tau_c / (1 - R)``."""
    return tau_c / (1 - R)

# Mean of the unmasked estimates along axis 1 (presumably the repeated fits).
# tau is converted from frames to seconds (presumably 30 frames per second).
R_hat = R_masked.mean(axis=1)
tau_hat = tau_masked.mean(axis=1) / 30

# one-parameter theory fit and an ordinary linear regression on the same points
popt_HitAndMiss, _ = curve_fit(fit_func, R_hat, tau_hat, p0=[100])
gradient, intercept, r_value, p_value, std_err = scipy.stats.linregress(R_hat, tau_hat)

R_space = np.linspace(0.01, 0.999, 1000)

# coefficient of determination 1 - SS_res / SS_tot for both models
sstot = np.sum((tau_hat - np.mean(tau_hat))**2)

exp_residuals = tau_hat - fit_func(R_hat, *popt_HitAndMiss)
exp_ssres = np.sum(exp_residuals**2)
exp_rsquared = 1 - exp_ssres / sstot

linear_residuals = tau_hat - (gradient*R_hat + intercept)
linear_ssres = np.sum(linear_residuals**2)
linear_rsquared = 1 - linear_ssres / sstot

##plot
class CloseToOne(mscale.ScaleBase):
    """Axis scale that spreads out values approaching 1.

    Data are mapped through ``x -> -log10(1 - x)``, so 0, 0.9, 0.99, 0.999, ...
    are equally spaced. The scale is registered under the name ``'close_to_one'``.

    Keyword arguments
    -----------------
    nines : int, default 5
        Major ticks sit at ``1 - 10**-k`` for ``k = 0 .. nines``. The upper view
        limit is capped at ``1 - 10**-nines``, and the forward transform masks
        values above ``1 - 10**-(nines + 1)``.
    """
    name = 'close_to_one'

    def __init__(self, axis, **kwargs):
        super().__init__(axis)
        self.nines = kwargs.get('nines', 5)

    def get_transform(self):
        """Return the forward (data -> scaled) transform."""
        return self.Transform(self.nines)

    def set_default_locators_and_formatters(self, axis):
        """Place major ticks at 0, 0.9, 0.99, ... and minor ticks at 1 - m*10**-k, m = 2..9."""
        major_ticks = [1 - 10**(-k) for k in range(1 + self.nines)]
        axis.set_major_locator(FixedLocator(np.array(major_ticks)))
        axis.set_major_formatter(FixedFormatter([str(tick) for tick in major_ticks]))

        minor_ticks = [1 - m * 10**(-k)
                       for k in range(1 + self.nines)
                       for m in range(2, 10)]
        axis.set_minor_locator(FixedLocator(np.array(minor_ticks)))

    def limit_range_for_scale(self, vmin, vmax, minpos):
        """Cap the upper view limit at the largest major tick, 1 - 10**-nines."""
        return vmin, min(1 - 10**(-self.nines), vmax)

    class Transform(mtransforms.Transform):
        """Forward transform ``a -> -log10(1 - a)``."""
        input_dims = 1
        output_dims = 1
        is_separable = True

        def __init__(self, nines):
            mtransforms.Transform.__init__(self)
            self.nines = nines

        def transform_non_affine(self, a):
            # Mask values beyond the last resolvable decade. The log must be taken
            # of the *masked* array; the original applied it to `a`, so the mask
            # was computed but never used.
            masked = ma.masked_where(a > 1 - 10**(-1 - self.nines), a)
            if masked.mask.any():
                return -ma.log10(1 - masked)
            return -np.log10(1 - a)

        def inverted(self):
            return CloseToOne.InvertedTransform(self.nines)

    class InvertedTransform(mtransforms.Transform):
        """Inverse transform ``a -> 1 - 10**(-a)``."""
        input_dims = 1
        output_dims = 1
        is_separable = True

        def __init__(self, nines):
            mtransforms.Transform.__init__(self)
            self.nines = nines

        def transform_non_affine(self, a):
            return 1. - 10**(-a)

        def inverted(self):
            return CloseToOne.Transform(self.nines)
mscale.register_scale(CloseToOne)

fig, ax = plt.subplots(dpi=300, figsize=(2.15,1.75))

# per-session estimates as open red circles
ax.scatter(R_hat, tau_hat,
           s=20, c='none', edgecolor='red', linewidth=2,
           label='data')
# fitted theory curve over the recurrence range
ax.plot(R_space, fit_func(R_space, *popt_HitAndMiss),
        color='grey', linestyle='-', linewidth=2,
        label='theory')

ax.set_ylim(0, 10)
ax.set_xscale('close_to_one')
ax.set_xlim(0, 0.999)

ax.set_xlabel('recurrence $R$')
ax.set_ylabel(r'$\tau_{post}$ (s)')

fig.tight_layout()
fig.savefig("../final_reports/Figure5l.pdf")

In [18]:
#Figure 5m
##load data
hit_tau_hats = np.load('../final_data/PulseBasedRecurrencey_hit_tau_hats_1000MaxPerms.npy')
miss_tau_hats = np.load('../final_data/PulseBasedRecurrencey_miss_tau_hats_1000MaxPerms.npy')

FPS = 30  # frames per second; tau_hats are presumably stored in frames
JITTER_SEED = 0  # fixed seed so the x-jitter, and hence the saved PDF, is reproducible

# median over axis 1 (presumably the permutations), one value per session
hit_medians = np.nanmedian(hit_tau_hats, axis=1)
miss_medians = np.nanmedian(miss_tau_hats, axis=1)

##plot
# paired per-session comparison of miss vs hit tau_post (in seconds)
stats = pg.wilcoxon(miss_medians/FPS, hit_medians/FPS)
p_val = stats.loc['Wilcoxon', 'p-val']

# derive the session count from the data instead of hard-coding it
n_sessions = len(hit_medians)
assert len(miss_medians) == n_sessions, "hit and miss medians must be paired per session"

fig, ax = plt.subplots(figsize=(2.25,1.75), dpi=300)

ax.set_ylim(0,10)

tmp_df = pd.DataFrame(
    {
        "taupost": np.concatenate([miss_medians/FPS, hit_medians/FPS]),
        "trial_type": ["miss"] * n_sessions + ["hit"] * n_sessions,
    }
)
jitter_rng = np.random.default_rng(JITTER_SEED)
tmp_df['x'] = jitter_rng.standard_normal(2*n_sessions) * 0.1

miss_df = tmp_df[tmp_df['trial_type'] == 'miss']
hit_df = tmp_df[tmp_df['trial_type'] == 'hit']

# jittered points: miss around x = 0, hit around x = 1
ax.plot(miss_df['x'], miss_df['taupost'], '.', color='k', markersize=10)
ax.plot(1 + hit_df['x'], hit_df['taupost'], '.', color='k', markersize=10)

# one line per session connecting its miss and hit values
ax.plot([miss_df['x'].to_numpy(), 1 + hit_df['x'].to_numpy()],
        [miss_df['taupost'].to_numpy(), hit_df['taupost'].to_numpy()],
        c='k', alpha=0.7)

ax.set_title("p = {:.3f}".format(p_val), fontsize=10)

ax.set_xlabel('')
ax.set_xticks([0, 1])
ax.set_xticklabels(['miss', 'hit'])

ax.set_ylabel(r"$\tau_{post}$ (s)")

ax.tick_params(bottom=False)
sns.despine()

fig.tight_layout()

fig.savefig('../final_reports/Figure5m.pdf')