In [None]:
import re, sys
sys.path.append('/home/dikshag/pbups_ephys_analysis/pbups_phys/multiregion_RNN/')
from multiregion_RNN_utils import *


import seaborn as sns
# Global plotting theme: seaborn "talk" context, tick-style axes, Helvetica text.
# (set_theme below replaces the "white" style chosen by set_style.)
sns.set_style("white")
sns.set_theme(
    context='talk',
    style='ticks',
    font='Helvetica',
    font_scale=1.3,
    rc={"axes.titlesize": 13},
)
# Pin Helvetica on the matplotlib side too, and render negative tick labels
# with an ASCII hyphen instead of the Unicode minus sign.
plt.rcParams.update({'font.family': 'Helvetica',
                     'axes.unicode_minus': False})


# Project locations: the directory holding the fitted-network outputs and the
# directory that figures are saved into.
base_path, figure_dir = get_base_path(), get_figure_dir()

In [2]:
# Common name prefix of this run's outputs. It is used as a regex pattern to pick
# the per-network result directories out of base_path, and is suffixed to locate
# the saved analysis summaries (.npy files).
fitfile = "run_optomul2023_1d_30_30_5"

In [3]:
# Entries in base_path matching the fitfile pattern (each later passed to
# reinitialize_network as one fitted network), excluding the .npy analysis
# summaries saved alongside them.
files = sorted(
    fn for fn in os.listdir(base_path)
    if re.search(fitfile, fn) and not fn.endswith(".npy")
)
print(len(files))

10


## Plot psychometric curves of the fitted networks

In [None]:
# Overlay each fitted network's psychometric curve (grey) and the target curve
# (brown) on a single axis. Every network is tested on the same trial batch,
# which is drawn once from the first network's task object.
fig_perf, ax_perf = plt.subplots(figsize=(5, 5))

# Override the saved task settings: a large test batch with no probe and no
# opto trials, so the curves reflect unperturbed model performance.
new_params = {'N_batch': 1000,
              'p_probe': 0.,
              'frac_opto': 0.,
              'probe_duration': 0}

for idx, fname in enumerate(files):
    print("\n\t" + fname)
    FOF_ADS, pc_data, _ = reinitialize_network(fname, new_params)

    # Draw the shared trial batch only once so all networks see identical stimuli.
    if idx == 0:
        x, y, m, params = pc_data.get_trial_batch()

    output, activity = FOF_ADS.test(x)
    df_trial, activity, data_dict = format_data(x, y, m, params, output, activity)
    plot_psych(df_trial, 'choice', ax_perf, color='grey')
    # Tear down this network before the next one is built.
    FOF_ADS.destruct()

# The target column is presumably derived from the shared batch (y), so the
# df_trial left over from the last iteration is sufficient here.
plot_psych(df_trial, 'choice_target_end', ax_perf, legend='Target', color='brown')
savethisfig(figure_dir, 'psych_RNNfits')

## Plot inactivation effects in the fitted networks (unilateral opto and bilateral FOF)

In [None]:
def process_inactivation_summary(summary, perturb_grps, gains = None,
                                 epochs = ('first_half', 'second_half')):
    """Collect per-network inactivation effects into a plot-ready nested dict.

    Parameters
    ----------
    summary : dict
        Nested mapping ``summary[group][epoch][gain][metric]``. Groups whose
        name contains ``'bi'`` are read directly via their ``'accuracy'``
        entry. Any other group ``grp`` is read from both ``'left_' + grp`` and
        ``'right_' + grp`` via their ``'bias'`` entry.
    perturb_grps : list of str
        Group names to extract, e.g. ``['FOF', 'proj', 'ADS']`` or ``['bi_FOF']``.
    gains : list, optional
        Gain keys to extract. Defaults to the gain keys stored under the first
        group of ``summary`` for the first epoch.
    epochs : sequence of str, optional
        Epoch keys to extract; defaults to ``('first_half', 'second_half')``.

    Returns
    -------
    dict
        ``data[grp][gain][epoch]`` -> 1-D numpy array. For unilateral groups the
        left-side values come first, followed by the right-side values.
    """
    epochs = list(epochs)
    if gains is None:
        gains = list(summary[list(summary)[0]][epochs[0]])

    data = {grp: {gain: {} for gain in gains} for grp in perturb_grps}
    for grp in perturb_grps:
        for gain in gains:
            for epoch in epochs:
                if 'bi' in grp:
                    values = np.ravel(summary[grp][epoch][gain]['accuracy'])
                else:
                    # Flatten each side separately and then join them. Calling
                    # np.ravel on the [left, right] pair would build a ragged
                    # array (rejected by recent NumPy) whenever the two sides
                    # hold different numbers of values.
                    values = np.concatenate([
                        np.ravel(summary[side + '_' + grp][epoch][gain]['bias'])
                        for side in ('left', 'right')
                    ])
                data[grp][gain][epoch] = values
    return data
 

In [None]:
# Load the saved inactivation summary for this run (pickled dict inside a .npy;
# allow_pickle is only safe because the file is produced locally by this project).
query_filename = fitfile + "_inactivation_train_summary_"
summary_file = sorted([fn for fn in os.listdir(base_path) if re.findall(query_filename, fn) and fn.endswith(".npy")])
if not summary_file:
    raise FileNotFoundError(f"No '{query_filename}*.npy' summary found in {base_path}")
# NOTE(review): if several summaries match, the alphabetically first one is used.
summary = np.load(base_path + summary_file[0], allow_pickle = True).item()


# Unilateral inactivations (FOF, proj, ADS), then bilateral FOF inactivation,
# each extracted at gain key 0.1 and saved as its own figure.
for perturb_grps, fig_name in [(['FOF', 'proj', 'ADS'], 'unilateral_RNN_inactivations'),
                               (['bi_FOF'], 'biFOF_RNN_inactivations')]:
    data = process_inactivation_summary(summary, perturb_grps, gains = [0.1])
    plot_inactivation_summary(data)
    savethisfig(figure_dir, fig_name)


## Plot and save the recurrent weight matrix

In [None]:
# Plot the recurrent weight matrix (W_rec) of the first fitted network using
# smaller fonts, then reset the theme for the figures that follow.
sns.set_theme(context='talk', style='ticks',  font='Helvetica', font_scale=0.5,  rc={"axes.titlesize": 13})

dirname = base_path + files[0] + os.sep
# Use the NpzFile as a context manager so its file handle is closed. Also avoid
# the name `x`, which the other cells use for the trial batch.
with np.load(dirname + 'final_weights.npz', allow_pickle = True) as weights:
    W_rec = weights['W_rec']
fig, axs = plt.subplots(figsize = (5,5))
plot_weights(W_rec, ax = axs)
savethisfig(figure_dir, "weight_matrix")

# NOTE(review): this resets font_scale to 1.25, while the setup cell uses 1.3;
# confirm which size the remaining figures are meant to use.
sns.set_theme(context='talk', style='ticks',  font='Helvetica', font_scale=1.25,  rc={"axes.titlesize": 13})


### Plot results from decoding analysis

In [None]:
# load the saved summary
query_filename = fitfile + "_choice_stim_decoding_"
summary_file = sorted([fn for fn in os.listdir(base_path) if re.findall(query_filename, fn) and fn.endswith(".npy")])
summary = np.load(base_path + summary_file[0], allow_pickle = True).item()

# get dt for setting the x-axis in decoding plots
dirname = base_path + files[0] + os.sep
network_params = np.load(dirname + 'network_params.npy', allow_pickle = True).item()
dt = network_params['dt']

# reformatting the saved summary for plotting
variables = ['stim', 'choice']
regions = ['FOF', 'ADS']
dec = {reg: {var: [] for var in variables} for reg in regions}
for var in variables:
    for reg in regions:
        for file in files:
            dec[reg][var].append(summary[var][reg][file]['accuracy'])
        dec[reg][var] = np.array(list(itertools.zip_longest(*dec[reg][var],
                                                                fillvalue=np.nan))).T
        
title_dict = {'stim': 'Stimulus decoding',
              'choice': 'Choice decoding'}


cols =  {'FOF': [0.0, 0.6431372549019608, 0.8],
         'ADS': [0.9137254901960784, 0.45098039215686275, 0.5529411764705883]}

# okay plotting
fig, axs = plt.subplots(1, len(variables), figsize = (3*len(variables), 3))
for ax, var in zip(axs.ravel(), variables):    
    for reg in regions:
        t = 0.001*np.arange(np.shape(dec[reg][var])[1])*dt
        npts = np.sum(~np.isnan(dec[reg][var]), axis =0)
        sem = np.nanstd(dec[reg][var], axis = 0)/np.sqrt(npts)
        mean = np.nanmean(dec[reg][var], axis = 0)
        ax.plot(t, mean, c = cols[reg], label = "model " + reg)
        ax.fill_between(t, mean - sem, mean + sem, color = cols[reg], alpha = 0.3)
        
    ax.set_xlabel('Time from clicks on [s]')
    ax.set_title(title_dict[var])
    ax.set_ylabel('corr(#R-#L, predicted)' if var == 'stim' else 'Choice decoding accuracy')

   
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, labels, frameon = False)
plt.tight_layout()
sns.despine()

savethisfig(figure_dir, 'stim_choice_decoding')

### Plot example single-unit PSTHs

In [None]:
# Simulate one fitted network (files[1]) on a fresh batch of trials and plot
# example single-unit PSTHs split by trial difficulty.
# As in the psychometric cell, probe and opto trials are disabled so the PSTHs
# reflect unperturbed activity ('frac_opto' would otherwise fall back to the
# network's saved setting).
new_params = {'N_batch': 1000,
              'p_probe': 0.,
              'frac_opto': 0.,
              'probe_duration': 0}


FOF_ADS, pc_data, _ = reinitialize_network(files[1], new_params)
x, y, m, params = pc_data.get_trial_batch()
output, activity = FOF_ADS.test(x)
df_trial, activity, _ = format_data(x, y, m, params, output, activity)

# Pass the activity through the network's transfer function to get firing rates.
# The `with` statement closes the session on exit; no explicit close() is needed.
with tf.compat.v1.Session() as sess:
    firing_rate = sess.run(FOF_ADS.transfer_function(activity))

# Example units. Per the original notes, the indices were presumably drawn with:
# np.random.seed(30)
# np.random.randint(0,300,4)
cell_idx = [176, 262, 32, 140]
make_RNN_PSTHs(df_trial,
               firing_rate[cell_idx, :, :],
               FOF_ADS.dt,
               split_by = 'difficulty',
               align_to = 'stim_onset',
               num_per_col = 2)
savethisfig(figure_dir, 'PSTHs')

# Tear down the network once plotting is done, as the psychometric cell does.
FOF_ADS.destruct()