In [1]:
import sys
# Extend the module search path so the project-local modules imported below
# can be found. These are relative paths, so they resolve against the
# kernel's current working directory.
# NOTE(review): which of the later imports live in which directory is not
# visible from this notebook — confirm before pruning any of these entries.
sys.path.append("../data/saved_models/")
sys.path.append("../model_scripts/")
sys.path.append("../utils/")
import os
import json

import numpy as np
import torch

import matplotlib.pyplot as plt
import fig3_plots

import fig3_analysis as rnn
import basic_analysis as basic
import model_utils
from task import generate_batch

import scipy
from scipy import stats
from sklearn.decomposition import PCA
from scipy.ndimage import gaussian_filter1d
from scipy.spatial import distance as dist

In [6]:
# file paths (relative to the notebook's working directory)
data_folder = "../data/saved_models/2d_2map/"
save_folder = "../figures/fig3_plots/"

# Make sure the figure output folder exists before any savefig call.
# os.makedirs also creates a missing parent (e.g. ../figures/), which
# os.mkdir would reject with FileNotFoundError; exist_ok=True tolerates the
# folder appearing between the isdir check and the creation.
if os.path.isdir(save_folder):
    print('save folder exists')
else:
    os.makedirs(save_folder, exist_ok=True)

# font sizes
title_size = 10

save folder exists


In [7]:
# get the model IDs for all saved models
# os.listdir returns entries in arbitrary order, so sort them to make the
# model list (and the example model picked below) reproducible across
# machines. Hidden entries such as .DS_Store or .ipynb_checkpoints are
# skipped so they are not mistaken for model IDs.
model_IDs = sorted(
    entry for entry in os.listdir(data_folder)
    if not entry.startswith('.')
)
if not model_IDs:
    # fail here with a clear message instead of an IndexError below
    raise FileNotFoundError(f"no saved models found in {data_folder}")

# select example model
ex_id = 0
model_ID = model_IDs[ex_id]

In [9]:
# get sample rnn data for the example model
inputs, outputs, targets = model_utils.sample_rnn_data(data_folder, model_ID)

# pass the hidden states and the map/position targets through the
# formatting helper; X, map_targ and pos_targ feed the figure cells below
X, map_targ, pos_targ = model_utils.format_rnn_data(
    outputs["hidden_states"],
    targets["map_targets"],
    targets["pos_targets"],
)

In [None]:
''' Figure 3A: model schematic '''
# inputs panel: built from the velocity and remap inputs
fig, gs = fig3_plots.plot_a1(
    inputs["inp_vel"],
    inputs["inp_remaps"],
)
plt.show()
fig.savefig(f'{save_folder}inputs.png', dpi=600, bbox_inches='tight')

# outputs panel: position targets alongside the position outputs and map logits
fig, gs = fig3_plots.plot_a2(
    targets["pos_targets"],
    outputs["pos_outputs"],
    outputs["map_logits"],
)
plt.show()
fig.savefig(f'{save_folder}outputs.png', dpi=600, bbox_inches='tight')

In [None]:
''' summary performance across models '''
# plotted over every saved model; displayed only, this figure is not saved
fig, ax = fig3_plots.plot_supp_1(data_folder, model_IDs)
ax.set_title(
    '2D pos\n2 maps',
    fontsize=title_size,
    pad=5,
)
plt.show()

In [None]:
''' Figure 3B: example tuning '''
# uses the formatted data from the example model
fig, axes = fig3_plots.plot_b(
    X,
    map_targ,
    pos_targ,
)
plt.show()
fig.savefig(f'{save_folder}ex_tuning.png', dpi=600, bbox_inches='tight')

In [None]:
''' Figure 3C: aligned toroidal manifolds'''
# the same plot is drawn twice, first with color_x=True and then with
# color_x=False, each saved under its own file name
manifold_panels = (
    (True, f'{save_folder}manifolds_xpos.png'),
    (False, f'{save_folder}manifolds_ypos.png'),
)
for use_x, out_path in manifold_panels:
    fig, axes = fig3_plots.plot_c(
        X, pos_targ, map_targ,
        num_points=4000,
        color_x=use_x,
    )
    plt.show()
    fig.savefig(out_path, dpi=600, bbox_inches='tight')

In [None]:
''' Figure 3D: summary of dimensionality '''
# summary over every saved model in data_folder
fig, axes = fig3_plots.plot_d(
    data_folder,
    model_IDs,
    top_num=4,
    top_num_1=3,
)
plt.show()
fig.savefig(f'{save_folder}PCs_all.png', dpi=600, bbox_inches='tight')

In [None]:
''' Figure 3E: summary of torus alignment '''
# summary over every saved model; the figure is written to disk before it is shown
fig, ax = fig3_plots.plot_e(
    data_folder,
    model_IDs,
)
fig.savefig(f'{save_folder}alignment.png', dpi=600, bbox_inches='tight')
plt.show()

In [None]:
''' Figure 3F: slices showing aligned rings 
subsample holding x or y constant to get a slice in each direction and plot the aligned rings
'''
# columns 0 and 1 of the position targets
x_pos = pos_targ[:, 0]
y_pos = pos_targ[:, 1]

# boolean masks picking samples whose coordinate rounds (2 decimals) to zero
y_idx = np.round(y_pos, 2) == 0
x_idx = np.round(x_pos, 2) == 0

# slice from a fixed Y position
fig, ax = fig3_plots.plot_f(X[y_idx], pos_targ[y_idx])
plt.show()
fig.savefig(f'{save_folder}slices_xpos.png', dpi=600, bbox_inches='tight')

# slice from a fixed X position
fig, ax = fig3_plots.plot_f(X[x_idx], pos_targ[x_idx])
plt.show()
fig.savefig(f'{save_folder}slices_ypos.png', dpi=600, bbox_inches='tight')

In [None]:
''' Figure 3g: alignment to remapping dim and position subspace '''
# summary over every saved model in data_folder
fig, ax = fig3_plots.plot_g(
    data_folder,
    model_IDs,
)
plt.show()
fig.savefig(f'{save_folder}dim_angles.png', dpi=600, bbox_inches='tight')