In [None]:
# glm related imports
from common_imports import *
from my_imports import *
from train_model import *
from glmfits_utils import *
from simulations import *
from basis_kernels import *

# data loading related imports
from helpers.phys_helpers import get_sortCells_for_rat, datetime, split_trials, flatten_list, savethisfig
from helpers.physdata_preprocessing import load_phys_data_from_Cell
from helpers.rasters_and_psths import get_neural_activity

# logistic decoding related imports
sys.path.insert(1, '../../../figure_code/')
from figure2.fig2_helpers.logisticdecoding import run_logistic_decoding, equalize_trials
from figure2.fig2_helpers.evidencedecoding import get_evidence, run_evidence_decoding

from sklearn.linear_model import LassoCV, LinearRegression
from sklearn.model_selection import KFold, cross_val_predict

# Figure plotting settings shared by every panel in this notebook.
plt.rcParams.update({
    'figure.dpi': 150,
    'figure.figsize': [4, 3],
    'font.family': "Helvetica",
    'axes.titlesize': 10,
    'xtick.major.size': 2,
    'ytick.major.size': 2,
    # 8 evenly spaced colours from the Spectral colormap for the line cycle.
    'axes.prop_cycle': plt.cycler("color", plt.cm.Spectral(np.linspace(0, 1, 8))),
    'figure.max_open_warning': 0,
})

sns.set_context('paper')



In [None]:
def get_CV_logll(
    GLM_DIR,
    file,
    region_X,
    region_Y):
    """Summarise the cross-validated score of one saved reduced-rank GLM fit.

    Args:
        GLM_DIR: directory of the fit file (joined to ``file`` by plain string
            concatenation, so it must end with a separator).
        file: name of the ``.npy`` file holding the pickled fit dictionary.
        region_X, region_Y: source/target region names; the fit is looked up
            under ``model_dict['<region_X>_<region_Y>']``.

    Returns:
        dict with ``mean_val_loss_per_spike`` (mean over repeats of the
        per-repeat fold-average of the negated validation loss),
        ``sem_val_loss_per_spike`` (population std / sqrt(n_repeats)),
        ``filename`` (session file stored in the fit) and ``fit_filename``.
    """
    fit_record = np.load(GLM_DIR + file, allow_pickle=True).item()
    fits = fit_record['model_dict'][region_X + '_' + region_Y]
    n_repeats = fit_record['p']['num_repeats']
    n_folds = fit_record['p']['num_folds']

    # Folds are keyed 1..num_folds. The loss is negated, presumably so that
    # larger is better (a log-likelihood) — TODO confirm the sign convention.
    per_repeat = [
        np.mean([-1 * fits[rep][fold]['validation_loss'] for fold in range(1, n_folds + 1)])
        for rep in range(n_repeats)
    ]

    return {
        'mean_val_loss_per_spike': np.mean(per_repeat),
        'sem_val_loss_per_spike': np.std(per_repeat) / np.sqrt(len(per_repeat)),
        'filename': fit_record['filename'],
        'fit_filename': file,
    }


# GET SUMMARY: cross-validated log-likelihood per (fit direction, rank, session fit file)
RRGLM_DIR = SPEC.RESULTDIR + 'neural_rr_GLM/'
ranks = range(1,6)
# Only fit files produced on these dates are included in the rank sweep.
FIT_DATES = ('25-05-2024', '26-05-2024', '27-05-2024')

# List the directory once rather than once per (direction, rank) pair.
all_rrglm_files = sorted(fn for fn in os.listdir(RRGLM_DIR) if '.npy' in fn)

# Collect one record per fit and build the DataFrame once; concatenating
# single-row frames inside the loop is quadratic.
records = []
for (region_X, region_Y) in zip(['FOF', 'ADS'], ['ADS', 'FOF']):
    for rank in ranks:
        rrglm_files = [fn for fn in all_rrglm_files
                       if any(f'rank{rank}_{date}' in fn for date in FIT_DATES)]

        for filename in rrglm_files:
            meta = get_CV_logll(RRGLM_DIR, filename, region_X, region_Y)
            meta['rank'] = rank
            meta['fit_type'] = region_X + '_' + region_Y
            records.append(meta)

summary = pd.DataFrame(records)

### Find the optimal communication-subspace rank for each session

In [None]:
# Upper bound on the communication-subspace (CS) rank carried forward: a
# session whose C.V. log-likelihood peaks at a higher rank is assigned this one.
MAX_CS_RANK = 4

# optimal_rank_dict[fit_type][session_filename] -> {'rank', 'fit_filename'}
optimal_rank_dict = dict()
filenames = np.unique(summary['filename'])

for fit_type in np.unique(summary["fit_type"]):
    optimal_rank_dict[fit_type] = dict()
    for file in filenames:
        # Boolean masks instead of a formatted query string: a filename
        # containing a quote would otherwise break the query.
        this_df = summary[(summary['filename'] == file) & (summary['fit_type'] == fit_type)]
        ranks = np.unique(this_df["rank"])
        y = np.array([this_df.loc[this_df["rank"] == r, 'mean_val_loss_per_spike'].values[0] for r in ranks])

        optimal_rank = min(MAX_CS_RANK, ranks[np.argmax(y)])
        fit_filename = this_df.loc[this_df["rank"] == optimal_rank, 'fit_filename'].values[0]
        optimal_rank_dict[fit_type][file] = {
            'rank': optimal_rank,
            'fit_filename': fit_filename,
        }
        


### Plot cross-validated log-likelihood (logll) as a function of rank

In [None]:
# Sessions shown in the log-likelihood panels.
# NOTE(review): the last session (in sorted order) is dropped here but is
# included in the rank selection above — confirm this is intended.
filenames = np.unique(summary['filename'])[:-1]
fig, axs = plt.subplots(2,2, figsize = (4,3), sharex = True)
rank_ticks = range(1, 6)
# Integer-centred bin edges (0.5, 1.5, ..., 5.5) so each histogram bar sits on one rank tick.
rank_bins = np.arange(min(rank_ticks) - 0.5, max(rank_ticks) + 1.5)

for f, fit_type in enumerate(np.unique(summary["fit_type"])):
    ax_ll, ax_hist = axs[0, f], axs[1, f]

    y_overall = []
    for file in filenames:
        this_df = summary[(summary['filename'] == file) & (summary['fit_type'] == fit_type)]
        ranks = np.unique(this_df["rank"])

        y = np.array([this_df[this_df["rank"] == rank]['mean_val_loss_per_spike'].values[0] for rank in ranks])
        yerr = np.array([this_df[this_df["rank"] == rank]['sem_val_loss_per_spike'].values[0] for rank in ranks])

        # Min-max normalise each session's curve; the SEM is divided by the
        # same range so the error bars are in the normalised units of y.
        min_y, max_y = np.min(y), np.max(y)
        y = (y - min_y)/(max_y - min_y)
        yerr = yerr/(max_y - min_y)
        y_overall.append(y)
        ax_ll.errorbar(ranks, y, yerr=yerr, fmt='-o', alpha = 0.5, ms = 2)

    ax_ll.set_xticks(rank_ticks)
    ax_ll.set_title("ADS $\\rightarrow$ FOF" if fit_type == 'ADS_FOF' else "FOF $\\rightarrow$ ADS")

    # Across-session mean ± SEM of the normalised curves (assumes every
    # session was fit at the same set of ranks).
    y_overall_mean = np.mean(y_overall, 0)
    y_overall_sem = np.std(y_overall, 0)/np.sqrt(len(filenames))
    ax_ll.errorbar(ranks, y_overall_mean, yerr = y_overall_sem, fmt = '-o', c = 'k', ms = 3, zorder = 100)

    # plot picked ranks
    picked_ranks = [optimal_rank_dict[fit_type][file]['rank'] for file in filenames]
    ax_hist.hist(picked_ranks, bins = rank_bins, color = 'grey', edgecolor = 'k')
    ax_hist.set_ylim(0,10)

fig.text(0.6, 0.0, 'Rank of comm. subspace (CS)', ha='center')
axs[0,0].set_ylabel('Norm. C.V. logll \nper spike $\\pm$ SEM')
axs[1,0].set_ylabel('Number of sessions \n with rank')

fig.align_ylabels(axs[:, 0])
plt.tight_layout()
sns.despine()
savethisfig(SPEC.FIGUREDIR + "figure2/", 'figure2_CSrank')


### Plot cross-validated R² distributions across sessions at the optimal rank

In [None]:
def get_bestfit_params(GLM_DIR, file, region_X, region_Y):
    """Load the stored best-fit parameters and fit metadata for one direction.

    Args:
        GLM_DIR: directory of the fit file (string-concatenated with ``file``).
        file: name of the pickled ``.npy`` fit dictionary.
        region_X, region_Y: source/target regions; selects
            ``model_dict['<region_X>_<region_Y>']``.

    Returns:
        (state_dict, prm_meta): the fit's ``state_dict`` and a dict holding the
        session-level entries ``filename``, ``covar_dict``, ``rank_dict``,
        ``p`` followed by the fit-level ``best_reg_param``,
        ``best_avg_val_loss`` and ``loss``.
    """
    saved = np.load(GLM_DIR + file, allow_pickle=True).item()

    session_keys = ('filename', 'covar_dict', 'rank_dict', 'p')
    fit_keys = ('best_reg_param', 'best_avg_val_loss', 'loss')

    prm_meta = {key: saved[key] for key in session_keys}
    fit = saved['model_dict'][region_X + '_' + region_Y]
    prm_meta.update({key: fit[key] for key in fit_keys})

    return fit['state_dict'], prm_meta



def get_data_for_glm(filename, p, region_X, region_Y):
    """Rebuild the design matrix and spike tensors a reduced-rank GLM was fit on.

    Args:
        filename: session file name; its first 4 characters are the rat name.
        p: fit parameter dict (uses ``fr_thresh``, ``window``, ``binsize`` and
            whatever ``get_X``/``get_Y`` read).
        region_X: region providing the input spikes.
        region_Y: region whose spikes are predicted.

    Returns:
        (X, Y_input, Y, df_trial, df_cell). ``df_cell`` keeps only neurons with
        ``stim_fr > p['fr_thresh']``; note ``reset_index()`` is called without
        ``drop=True``, so the old index survives as an ``index`` column.
    """
    rat = filename[:4]
    _, session_dir = get_sortCells_for_rat(rat, SPEC.DATADIR)
    df_trial, df_cell, _ = load_phys_data_from_Cell(session_dir + os.sep + filename)
    above_thresh = df_cell.stim_fr > p['fr_thresh']
    df_cell = df_cell[above_thresh].reset_index()

    X = get_X(df_trial, p)
    Y = get_Y(df_cell, df_trial, p, region = region_Y)

    # Input spikes use a copy of p whose window entries are shifted back by one
    # bin (presumably so region_X input precedes region_Y output — TODO confirm).
    lagged_p = deepcopy(p)
    lagged_p['window'] = [edge - lagged_p['binsize'] for edge in lagged_p['window']]
    Y_input = get_Y(df_cell, df_trial, lagged_p, region = region_X)

    return X, Y_input, Y, df_trial, df_cell



def load_glm_with_params(this_glm, state_dict):
    """Copy saved parameters into ``this_glm`` in place and return it.

    Every key other than ``'bias'`` and ``'rr_basis_w'`` is taken to be
    ``'<submodule>.weight'``; its value is written to that submodule's
    ``.weight.data``. ``bias`` and ``rr_basis_w`` are replaced with fresh
    ``nn.Parameter`` objects built from the saved values.
    """
    direct_params = ('bias', 'rr_basis_w')

    for key, value in state_dict.items():
        if key in direct_params:
            continue
        submodule = getattr(this_glm, key.replace('.weight', ''))
        submodule.weight.data = to_t(value)

    for name in direct_params:
        setattr(this_glm, name, nn.Parameter(to_t(state_dict[name])))

    return this_glm



def load_glm_model(GLM_DIR, file, data, region_X, region_Y, num_X, num_Y):
    """Construct a ``glm_rr`` on ``device`` and load its stored best-fit weights.

    Args:
        GLM_DIR, file: location of the saved fit (see ``get_bestfit_params``).
        data: unused; everything needed is reloaded from ``GLM_DIR + file``.
        region_X, region_Y: source/target regions of the fit.
        num_X: number of input (region_X) neurons, written into
            ``rank_dict['num_inputs']``.
        num_Y: number of predicted (region_Y) neurons.
    """
    best_state, fit_meta = get_bestfit_params(GLM_DIR, file, region_X, region_Y)

    rank_cfg = fit_meta['rank_dict']
    rank_cfg['num_inputs'] = num_X

    model = glm_rr(
        covar_dict = fit_meta['covar_dict'],
        num_Y = num_Y,
        rank_dict = rank_cfg,
        dt = fit_meta['p']['dt'],
    ).to(device)
    return load_glm_with_params(model, best_state)




def get_CV_predictions(
    GLM_DIR, 
    file, 
    region_X,
    region_Y,
    psth_filter_w = 75, 
    psth_filter_type = 'gaussian'):
    """Cross-validated predictions of a reduced-rank GLM for held-out trials.

    For every fold of the first repeat (``model_dict[...][0]``) the fold's
    parameters are loaded and the fold's validation trials are predicted, so
    each trial's prediction comes from parameters fit without it.

    Args:
        GLM_DIR: directory of the fit file (string-concatenated with ``file``).
        file: name of the saved ``.npy`` fit dictionary.
        region_X, region_Y: source (input) and target (predicted) regions.
        psth_filter_w, psth_filter_type: smoothing applied to the observed
            spikes when building the comparison target ``y``.

    Returns:
        yhat: array shaped like ``Y`` with the predictions (NaN for any trial
            that is in no validation fold).
        y: smoothed observed activity of region_Y.
        df_trial: the session's trial table.
        meta: per-neuron metadata (columns for ``pd.DataFrame``); the
            ``firing_rate`` entry is the mean of ``y`` over trials and time
            divided by ``p['dt']``.
    """
    import torch  # only for no_grad(): inference needs no autograd graph

    data = np.load(GLM_DIR + file, allow_pickle=True).item()
    # Only the first cross-validation repeat is used to build predictions.
    this_fit = data['model_dict'][region_X + '_' + region_Y][0]

    X, Y_input, Y, df_trial, df_cell = get_data_for_glm(
        data['filename'], 
        data['p'], 
        region_X,
        region_Y)
    df_cell = df_cell.loc[df_cell['region'] == region_Y].reset_index(drop = True)
    dataset = glmRRDataset(X,Y_input,Y)

    this_glm = load_glm_model(GLM_DIR, file, data, region_X, region_Y, Y_input.shape[2], Y.shape[2])

    yhat = np.full(np.shape(Y), np.nan)
    with torch.no_grad():
        for fold in range(1, data['p']['num_folds'] + 1):
            fold_fit = this_fit[fold]
            this_glm = load_glm_with_params(this_glm, fold_fit['state_dict'])
            this_dataset = dataset[fold_fit['val_index']]
            preds = this_glm(
                this_dataset['stimulus'], 
                this_dataset['input_spikes'], 
                this_dataset['spikes'])
            yhat[fold_fit['val_index'], :, :] = from_t(preds)

    p_psth = deepcopy(data['p'])
    p_psth['filter_w'] = psth_filter_w
    p_psth['filter_type'] = psth_filter_type
    y = get_Y(df_cell, df_trial, p_psth, region = region_Y)

    meta = dict()
    meta['cell_ID'] = df_cell['cell_ID']
    meta['firing_rate'] = np.nanmean(y, (0,1))/data['p']['dt']
    meta['region'] = df_cell['region']
    meta['auc'] = df_cell['auc']
    meta['side_pref'] = df_cell['side_pref']
    meta['side_pref_p'] = df_cell['side_pref_p']
    meta['stim_fr'] = df_cell['stim_fr']
    meta['filename'] = data['filename']
    meta['ratname'] = data['filename'][:4]
    meta['fit_filename'] = file

    return yhat, y, df_trial, meta


    
def compute_R2(yhat, y, df_trial, split_by = 'pokedR'):
    """Per-neuron R² of trial-averaged PSTHs.

    Args:
        yhat: predicted activity, (trials, time, neurons).
        y: observed activity, same shape.
        df_trial: trial table passed to ``split_trials`` when ``split_by`` is set.
        split_by: trial column used to form conditions, or None to compare
            the overall trial-averaged PSTH only.

    Returns:
        1-D array with one R² value per neuron (see ``compute_R2_travg``).
    """
    splits = split_trials(df_trial, split_by) if split_by is not None else None

    # y[:, :, k] is already (trials, time). No np.squeeze: it would collapse a
    # single-trial or single-bin slice to 1-D and break the 2-D indexing in
    # compute_R2_travg.
    return np.array([
        compute_R2_travg(yhat[:, :, cell_num], y[:, :, cell_num], splits)
        for cell_num in range(y.shape[2])
    ])


def compute_R2_travg(yhat, y, splits):
    """R² between observed and predicted trial-averaged PSTHs of one neuron.

    Args:
        yhat: predicted activity, (trials, time); NaNs are ignored.
        y: observed activity, same shape as ``yhat``.
        splits: mapping condition -> trial indices, or None.

    Returns:
        With ``splits``: 1 - SSE of the condition-mean PSTHs divided by the
        squared deviation of the observed condition means from the pooled
        mean PSTH (fraction of condition-dependent structure explained).
        Without ``splits``: 1 - SSE of the mean PSTH divided by the squared
        deviation of the observed mean PSTH from the mean firing rate
        (fraction of the PSTH's time-variation explained).
    """
    if splits is not None:
        conds = list(splits)
        mean_PSTH = np.nanmean(np.vstack([y[splits[c], :] for c in conds]), axis = 0)
        r2_num, r2_den = 0., 0.
        for c in conds:
            cond_mean = np.nanmean(y[splits[c], :], axis = 0)
            pred_cond_mean = np.nanmean(yhat[splits[c], :], axis = 0)
            r2_num += np.nansum((cond_mean - pred_cond_mean)**2)
            r2_den += np.nansum((mean_PSTH - cond_mean)**2)
    else:
        meanfr = np.nanmean(y)
        meanPSTH = np.nanmean(y, axis = 0)
        pred_meanPSTH = np.nanmean(yhat, axis = 0)
        r2_num = np.nansum((meanPSTH - pred_meanPSTH)**2)
        # Baseline: a flat PSTH at the observed mean rate. The squared gap
        # between observed and predicted mean rates is not a variance — it is
        # ~0 for a model that matches the mean rate and makes R² undefined.
        r2_den = np.nansum((meanPSTH - meanfr)**2)

    return (1 - (r2_num / r2_den))


def plot_R2_vs_firing_rate(data):
    """Joint plot of cross-validated R² against firing rate, coloured by region.

    Args:
        data: DataFrame with ``firing_rate``, ``r2`` and ``region`` columns.
            It is not modified; R² is clipped to [-0.2, 1] on a copy.

    Returns:
        The seaborn ``JointGrid``.
    """
    # Region name -> colour. A dict (not a set of colours) is needed so each
    # region is deterministically drawn in its own SPEC colour.
    col_pal = {key: tuple(SPEC.COLS[key]) for key in SPEC.COLS}
    # Clip on a copy so the caller's R² values are left untouched.
    data = data.assign(r2 = data['r2'].clip(-0.2, 1))

    g = sns.JointGrid(ratio = 2, height = 6)
    sns.scatterplot(data = data, x = "firing_rate", y = "r2", hue = "region", 
                    palette = col_pal, s=20, alpha = .5, ax=g.ax_joint)
    g.ax_joint.set_ylim([-0.1,1.05])
    g.ax_joint.axhline(0., c= [0.6, 0.6, 0.6], ls = '--', zorder = 0)
    sns.histplot(
        data = data, y = "r2", hue = "region", 
        stat = "probability", common_norm = False, binwidth = 0.05, 
        palette = col_pal, binrange = [-0.2, 1],
        ax=g.ax_marg_y, legend = False)
    sns.histplot(
        data = data, x = "firing_rate", hue = "region", 
        stat = "probability", common_norm = False, bins = 30,
        palette = col_pal, ax=g.ax_marg_x, legend = False)
    g.ax_joint.set_xlabel("Firing Rate (spk/s)", fontsize = 14)
    g.ax_joint.set_ylabel("Cross validated R2", fontsize = 14)
    return g
    

In [None]:
RRGLM_DIR = SPEC.RESULTDIR + 'neural_rr_GLM/'

psth_filter_w = 75
psth_filter_type = 'gaussian'
split_by = 'pokedR'

# One DataFrame per (direction, session); concatenated once at the end.
frames = []
for (region_X, region_Y) in zip(['FOF', 'ADS'], ['ADS', 'FOF']):
    print(region_X, region_Y)
    fit_type = "{}_{}".format(region_X, region_Y)
    for data_file, best in optimal_rank_dict[fit_type].items():
        print(best["fit_filename"])
        yhat, y, df_trial, meta = get_CV_predictions(
            RRGLM_DIR, 
            best['fit_filename'], 
            region_X,
            region_Y,
            psth_filter_w = psth_filter_w, 
            psth_filter_type = psth_filter_type)

        n_cells = y.shape[2]
        meta['r2'] = compute_R2(yhat, y, df_trial, split_by = split_by)
        # Record the rank selected for this session; a bare `rank` here would
        # be whatever value the earlier rank-sweep loop left behind.
        meta['rank'] = np.repeat(best['rank'], n_cells)
        meta['fit_type'] = [fit_type] * n_cells
        frames.append(pd.DataFrame(meta))

summary = pd.concat(frames, ignore_index = True) if frames else pd.DataFrame()


In [None]:
# Joint scatter / marginal histograms of C.V. R² vs firing rate, saved as a figure-2 panel.
plot_R2_vs_firing_rate(data = summary)
savethisfig(f"{SPEC.FIGUREDIR}figure2/", 'figure2_rrglm_R2')


### Summary statistics of R² per region

In [None]:
# Per region: fraction of neurons with R² < 0, and mean ± SEM / median R²
# over the neurons with R² >= 0.
for region in np.unique(summary.region):
    region_r2 = summary.loc[summary['region'] == region, 'r2']

    frac = (region_r2 < 0).mean()
    print("\n\nFraction less than 0 R2: {}, {}".format(region, frac))

    nonneg_r2 = region_r2[region_r2 >= 0]
    meanR2 = np.mean(nonneg_r2)
    semR2 = np.std(nonneg_r2)/np.sqrt(len(nonneg_r2))
    print("\nMeanR2 greater than 0 R2: {}, {} ± {}".format(region, meanR2, semR2))

    medianR2 = np.median(nonneg_r2)
    print("MedianR2 greater than 0 R2: {}, {}".format(region, medianR2))
    


### Run linear evidence decoding

In [None]:
def get_evidence(ntpts_per_trial, df_trial, p, ev_type="delta"):
    """Cumulative click evidence per trial, binned like the neural activity.

    Clicks are aligned to ``clicks_on`` (``p['align_to']`` is ignored, as the
    printed message says) and counted in consecutive bins of ``p['binsize']``
    starting at ``p['start_time'][0]`` (both in ms; converted to s).

    Note: this definition shadows the ``get_evidence`` imported from
    ``figure2.fig2_helpers.evidencedecoding``.

    Args:
        ntpts_per_trial: number of time bins for each trial (one per row of
            ``df_trial``).
        df_trial: trial table with ``leftBups``/``rightBups`` (click times)
            and ``clicks_on``, all in s.
        p: parameter dict; uses ``start_time`` (list, ms) and ``binsize`` (ms).
        ev_type: "delta" (right minus left), "right" or "left".

    Returns:
        (n_trials, max(ntpts_per_trial)) float array of cumulative click
        counts; bins past a trial's own length are NaN.

    Raises:
        ValueError: if ``ev_type`` is not one of the three options.
    """
    if ev_type not in ("delta", "right", "left"):
        raise ValueError("ev_type must be 'delta', 'right' or 'left', got {!r}".format(ev_type))

    print("aligning to clicks on ignoring definition in p")
    start_ms = p['start_time'][0]
    binsize_ms = p['binsize']
    evidence = np.full((len(df_trial), np.max(ntpts_per_trial)), np.nan)
    for tr, ntpts in zip(range(len(df_trial)), ntpts_per_trial):
        # Edges start at start_time and cover ntpts + 1 bins; only the first
        # ntpts are kept, so every kept bin is half-open like np.histogram's
        # interior bins. (Ending the edges at (ntpts+1)*binsize regardless of
        # the start offset gives too few bins when start_time > 0.)
        edges = (start_ms + binsize_ms * np.arange(ntpts + 2)) * 0.001
        counts_L, _ = np.histogram(df_trial['leftBups'][tr] - df_trial['clicks_on'][tr], edges)
        counts_R, _ = np.histogram(df_trial['rightBups'][tr] - df_trial['clicks_on'][tr], edges)
        if ev_type == "delta":
            counts = counts_R - counts_L
        elif ev_type == "right":
            counts = counts_R
        else:
            counts = counts_L
        evidence[tr, :ntpts] = np.cumsum(counts)[:ntpts]
    return evidence


In [None]:
RRGLM_DIR = SPEC.RESULTDIR + 'neural_rr_GLM/'

# Decoding parameters (times in ms unless noted otherwise).
p = {
    'ratnames': SPEC.RATS,
    'regions': SPEC.REGIONS,
    'cols': SPEC.COLS,
    'fr_thresh': 1.0,               # firing rate threshold for including neurons
    'stim_thresh': 0.0,             # stimulus duration threshold for including trials
    'align_to': ['clicks_on_delayed'],
    'align_name': ['clickson_masked'],
    'pre_mask': [None],
    'post_mask': ['clicks_off_delayed'],
    'start_time': [-100],
    'end_time': [1100],             # in ms
    'trial_type': ['all'],
    'Tshift': 100,                  # delay for stimulus (in ms)
    'binsize': 50,                  # in ms
    'filter_type': 'gaussian',
    'filter_w': 75,                 # in ms
    'Cs': np.logspace(-8, 8, 400),  # cross-validation parameter
    'nfolds': 10,                   # number of folds for cross-validation
    'n_repeats': 10,                # number of repeats for cross-validation
}


# For each fit direction (region_x -> region_y) and each session at its
# selected rank: decode click evidence (a) from region_y's population activity
# and (b) from the reduced-rank GLM's communication-subspace (CS) output.
# Results are collected in `summary` (one dict per session).
# NOTE(review): this rebinds the global `summary` (the R² DataFrame built in
# earlier cells) to a dict; cells that use the DataFrame will fail if re-run
# after this one.
# NOTE(review): the final np.save is commented out, so the results loaded in
# the next cells come from an earlier run on disk, not from this loop.
for fit_type in list(optimal_rank_dict):
    
    region_x, region_y = fit_type.split('_')

    for data_file in list(optimal_rank_dict[fit_type]):

        fit_file = optimal_rank_dict[fit_type][data_file]['fit_filename']
        data = np.load(RRGLM_DIR + fit_file, allow_pickle=True).item()
        # p_rrglm is the parameter dict stored with the GLM fit, overridden
        # with this cell's cross-validation settings for decoding.
        p_rrglm = data['p']
        p_rrglm['n_repeats'] = p['n_repeats']
        p_rrglm['nfolds'] = p['nfolds']
        p_rrglm['Cs'] = [0]
        p_rrglm['slide_t'] = 50
        # The global decoding params take the GLM's bin size (mutated per session).
        p['binsize'] = p_rrglm['binsize']


        rat = data['filename'][:4]
        print("\n\n\n\n===== RAT: {} =====".format(rat))
        files, this_datadir = get_sortCells_for_rat(rat, SPEC.DATADIR)
        df_trial, df_cell, _ = load_phys_data_from_Cell(this_datadir + os.sep + data['filename'])
        df_trial = df_trial[df_trial['stim_dur_s_actual'] >= p['stim_thresh']].reset_index(drop = True)
        df_trial = equalize_trials(df_trial, 'pokedR')

        # not equalizing neurons across regions
        df_cell = df_cell[df_cell['stim_fr'] >= p['fr_thresh']].reset_index(drop = True)

        # Alignment events shifted by Tshift ms (converted to s).
        df_trial['clicks_on_delayed'] = df_trial['clicks_on'] + 1e-3*p['Tshift']
        df_trial['clicks_off_delayed'] = df_trial['clicks_off'] + 1e-3*p['Tshift']

        # Writes the current decoding params to disk on every iteration.
        SAVEDIR = RRGLM_DIR + 'decoding/'
        fname = SAVEDIR + 'params' + datetime()[5:] + ".npy"
        np.save(fname, p)

        summary = dict()
        fname = SAVEDIR + 'decoding_' + fit_type + '_' + data_file[:-4] + datetime()[5:] + '.npy'


        ### DECODE FROM WHOLE REGION
        summary[region_y] = dict()
        print("\n\nDecoding from region: {} =====".format(region_y))
        X, ntpts_per_trial = get_neural_activity(df_cell, df_trial, region_y, p, 0)
        # get_evidence here resolves to the version defined in the previous
        # cell (it shadows the imported helper of the same name).
        evidence = get_evidence(ntpts_per_trial, df_trial, p)
        
        summary[region_y] = run_evidence_decoding(X, evidence, ntpts_per_trial, p)
            
        ### DECODE FROM GLM
        # Same key handling as load_glm_with_params: every key except 'bias'
        # and 'rr_basis_w' is '<submodule>.weight'.
        model_var_list = list(data['model_dict'][fit_type]['state_dict'])
        model_var_list = [item for item in model_var_list if (item != 'bias' and item != 'rr_basis_w')]
        model_variables = [item.replace('.weight', '') for item in model_var_list]

        X = get_X(df_trial, p_rrglm)

    
        print("\n\nDecoding from CS: {} =====".format(fit_type))
        summary[fit_type] = dict()
        Y = get_Y(df_cell, df_trial, p_rrglm, region = region_y)
        num_Y = Y.shape[2]

        # Input spikes from region_x with each window entry shifted back by one bin.
        p_input = deepcopy(p_rrglm)
        p_input['window'] = [x - p_input['binsize'] for x in p_input['window']]
        Y_input = get_Y(df_cell, df_trial, p_input, region = region_x)
        data['rank_dict']['num_inputs'] = Y_input.shape[2]

        dataset = glmRRDataset(X, Y_input, Y)
        
        this_glm = glm_rr(
            covar_dict = data['covar_dict'],
            num_Y = num_Y,
            rank_dict = data['rank_dict'],
            dt = p_rrglm['dt']).to(device)

        # Top-level state_dict of the fit (the same one get_bestfit_params returns).
        this_state_dict = data['model_dict'][fit_type]['state_dict']

        for val, val_id in zip(model_variables, model_var_list):
            getattr(this_glm, val).weight.data = to_t(this_state_dict[val_id])
        this_glm.bias = nn.Parameter(to_t(this_state_dict['bias']))
        this_glm.rr_basis_w = nn.Parameter(to_t(this_state_dict['rr_basis_w']))
        
        # from the activity after being filtered
        # The second output of preds_rr is used as the CS activity — presumably
        # the rank-limited projection of the region_x input; TODO confirm.
        _, output = this_glm.preds_rr(
            dataset.stimulus,
            dataset.input_spikes,
            dataset.spikes)
        output = from_t(output)
        
        # Now align the activity
        # start_index: offset (in GLM bins) between the GLM alignment event
        # p_rrglm['align_to'] and the decoding window start
        # p['align_to'][0] + start_time; event times are in s, so x1000 -> ms.
        aligned_output = []
        for tnum in range(len(df_trial)):
            start_index = 1000*(df_trial.loc[tnum, p['align_to'][0]] + (p['start_time'][0]/1000) - df_trial.loc[tnum, p_rrglm['align_to']])/p_rrglm['binsize']
            start_index = int(np.floor(start_index))
            aligned_output.append(output[tnum, range(start_index, start_index + ntpts_per_trial[tnum])])
        aligned_output = np.array(flatten_list(aligned_output)).T
        
        # and then decode
        summary[fit_type] = run_evidence_decoding(aligned_output, evidence, ntpts_per_trial, p_rrglm)
        summary[fit_type]['rr_vec'] = this_state_dict['rr_U_w.weight']

        # NOTE(review): saving is disabled; re-enable to regenerate the files
        # read by the loading cell below.
        # np.save(fname, summary)



## Load the saved decoding results

In [None]:
# Display the selected rank and fit file for each (fit direction, session).
optimal_rank_dict

In [None]:
# Saved per-session decoding results (files dated 28-05-2024).
# decoding_results[region_y] holds population-decoding accuracies and
# decoding_results[fit_type] the CS-decoding accuracies, in sorted-file order.
EV_DECODING_DIR = SPEC.RESULTDIR + 'neural_rr_GLM/decoding/'
decoding_files = sorted(
    fn for fn in os.listdir(EV_DECODING_DIR)
    if all(tag in fn for tag in ('.npy', 'decoding', '28-05-2024'))
)

decoding_results = dict()

for fit_type in list(optimal_rank_dict):
    region_x, region_y = fit_type.split('_')
    population_acc, cs_acc = [], []
    session_files = sorted(fn for fn in decoding_files if fit_type in fn)

    for session_file in session_files:
        saved = np.load(EV_DECODING_DIR + session_file, allow_pickle = True).item()
        population_acc.append(saved[region_y]['accuracy'])
        cs_acc.append(saved[fit_type]['accuracy'])

    decoding_results[region_y] = population_acc
    decoding_results[fit_type] = cs_acc


In [None]:
fig, axs = plt.subplots(1, 2, figsize = (4, 1.5), sharey = True, sharex = True)

# Time axis of the decoding curves (s from clicks on); only bins tstart:tend are shown.
xaxs = np.arange(0, 1.5, 0.005)
tstart, tend = 0, 165
window = slice(tstart, tend)

for k, fit_type in enumerate(list(optimal_rank_dict)):
    region_x, region_y = fit_type.split('_')
    ax = axs[k]
    colour = SPEC.COLS[region_x]
    pop_acc = decoding_results[region_x]
    cs_acc = decoding_results[fit_type]

    # CS decoding accuracy relative to source-region population decoding,
    # session by session (matched by sorted-file order), clipped to [0, 1].
    ratios = [np.clip(cs_acc[i][window] / pop_acc[i][window], 0, 1) for i in range(len(pop_acc))]
    n_valid = np.sum(~np.isnan(ratios), axis = 0)
    vals_mean = np.nanmean(ratios, axis = 0)
    vals_sem = np.nanstd(ratios, axis = 0) / np.sqrt(n_valid)

    t_shown = xaxs[window]
    ax.plot(t_shown, vals_mean, c = colour, ls = '-')
    ax.fill_between(t_shown, vals_mean - vals_sem, vals_mean + vals_sem, color = colour, alpha = 0.3)
    print(np.mean(vals_mean))
    ax.set_ylim([-0., 0.85])
    ax.set_xticks(np.arange(0, 1.0, 0.2))
    ax.set_xlabel('Time from clicks on [s]')
    ax.set_title(f"{region_x} $\\rightarrow$ {region_y} vs {region_x}")

axs[0].set_ylabel('Decoding performance \n(CS/population)')

sns.despine()
savethisfig(SPEC.FIGUREDIR + "figure2/", 'figure2_CSdecoding')
