In [None]:
from torch.utils.data import random_split

from scripts.avi_rnn import *
from scripts.utils import set_seed
from scripts.ds_class import *
from scripts.homeos import *
from scripts.plotting import *
from scripts.fit_motif import *
from scripts.time_series import *
from scripts.ra import *
from scripts.exp_tools import *
# Root of the AVI experiment results. Per-size result folders are read from
# here, and exported trajectories are written to its "all_trajs" subfolder.
exp_folder = "/".join(("code", "experiments", "avi"))

In [None]:
# Reproducibility and simulation settings shared by the cells below.
# These imports are bound here, after the star imports at the top, so no
# re-exported name can shadow them.
import random
import torch

random_seed = 42
# Seed every RNG the pipeline may draw from. Seeding only NumPy leaves the torch
# RNG (presumably used by the noisy RNN simulation and the motif fits) unseeded.
# NOTE(review): scripts.utils.set_seed is imported at the top and may already do
# this — consider calling it instead once its signature is confirmed.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

T = 25.6 / 2     # trial length passed to the task and to simulate_rnn_with_task
dt = .1          # time step passed to the task
batch_size = 50  # batch size passed to simulate_rnn_with_task

In [None]:
# Load every trained RNN for each network size and save its simulated trajectories
# to <exp_folder>/all_trajs/N{N}_{activation}/trajectories_<net_id>.npy.
# The task has zero input speed (speed_range=[0., 0.]) and equally spaced initial angles.
task = angularintegration_task_constant(T=T, dt=dt, speed_range=[0.,0.], sparsity=1, random_angle_init='equally_spaced')
for N in [64, 128, 256]:
    sub_exp = f"N{N}_T128_noisy"
    for activation in ["recttanh"]:  # NOTE: "relu" and "tanh" runs can be added here
        folder = os.path.join(exp_folder, sub_exp, activation)
        # glob returns entries in arbitrary filesystem order; sort for a reproducible run.
        exp_list = sorted(glob.glob(os.path.join(folder, "res*")))
        nact_exp = os.path.join(exp_folder, "all_trajs", f"N{N}_{activation}")
        print(f"Processing {nact_exp}")
        if not exp_list:
            print(f"No results matching 'res*' in {folder}")
            continue
        os.makedirs(nact_exp, exist_ok=True)
        for path in exp_list:
            try:
                net, result = load_net_path(path)
            except Exception as err:
                # Skip unreadable results, but say why. A bare `except:` would
                # also swallow KeyboardInterrupt and hide the cause.
                print(f"Error loading {path}: {err!r}")
                continue
            net.eval()
            # `inputs` rather than `input`, so the builtin is not shadowed globally.
            inputs, target, mask, output, trajectories = simulate_rnn_with_task(net, task, T, '', batch_size=batch_size)
            # The id is the part of the file name after the last "_" and before the first ".".
            net_id = os.path.basename(path).split("_")[-1].split(".")[0]
            print(net_id, trajectories.shape)
            np.save(os.path.join(nact_exp, f"trajectories_{net_id}.npy"), trajectories)


In [None]:
# Names of the 2-D archetype motifs fitted below (each is passed as `ds_motif`
# to run_on_target). Presumably lds = linear dynamical system, lc = limit cycle,
# bla = bounded line attractor — TODO confirm against scripts.fit_motif.
archetypes_2d = "lds lc ring bla bistable".split()

In [None]:
# Pick one exported trajectory file (one trained RNN) of the N=64 recttanh
# networks as the target for the motif fits in the next cell.
N = 64
activation = "recttanh"
exp_dir = Path('experiments')
# NOTE(review): the export cell writes under exp_folder ("code/experiments/avi/all_trajs/..."),
# but this reads "experiments/avi/all_trajs/...". The two only coincide if the working
# directory differs between cells — confirm the intended root.
data_dir = exp_dir  / 'avi' / "all_trajs" / f"N{N}_{activation}"

# Sort so the chosen target is reproducible (Path.glob order is arbitrary).
npy_files = sorted(data_dir.glob('*.npy'))
print(npy_files)
if not npy_files:
    raise FileNotFoundError(f"No .npy trajectory files found in {data_dir}")
file = npy_files[0]
target_name = file.name.removesuffix('.npy')
save_dir = Path('experiments_test') / 'avi' / "motif_fits" / f"N{N}_{activation}"
print(target_name)

In [None]:
# Fit every 2-D archetype to the single target selected in the previous cell
# (target_name / data_dir / save_dir).
for archetype in archetypes_2d:
    print(f"Running archetype fit for {archetype}")
    run_on_target(
        target_name,
        save_dir=save_dir,
        data_dir=data_dir,
        ds_motif=archetype,
        analytic=True,
        canonical=True,
        jac_lambda_reg=.0,
        num_epochs=200,
        quick_jac=True,
    )

In [None]:
## run ring on N128, N256
# Fit only the 'ring' motif to one exported trajectory file per larger network size.
activation = "recttanh"
exp_dir = Path('experiments')
for N in [128, 256]:
    data_dir = exp_dir  / 'avi' / "all_trajs" / f"N{N}_{activation}"

    # Sort so the chosen target is reproducible (Path.glob order is arbitrary).
    npy_files = sorted(data_dir.glob('*.npy'))
    if not npy_files:
        raise FileNotFoundError(f"No .npy trajectory files found in {data_dir}")
    file = npy_files[0]
    target_name = file.name.removesuffix('.npy')
    save_dir = Path('experiments_test') / 'avi' / "motif_fits" / f"N{N}_{activation}"
    run_on_target(target_name, save_dir=save_dir, data_dir=data_dir, ds_motif='ring', analytic=True, canonical=True, jac_lambda_reg=.0, num_epochs=200, quick_jac=True)