In [1]:
# Analysis for Fig 6d-e, Fig S7

import numpy as np
import matplotlib.pyplot as plt

import os
import sys
import pickle
import copy
from scipy.io import savemat
from scipy.linalg import schur
from sklearn.cluster import DBSCAN as dbscan
#from dynamics.process.rnn.parse import KEconvert2matlab
from scipy.stats import ranksums, wilcoxon,ks_2samp
import glob
import warnings
from sklearn.metrics import pairwise_distances
from os.path import exists
from scipy import interpolate
from copy import deepcopy

import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

import json



sys.path.insert(0, '/Users/dhocker/projects/kind_cl/')

from dynamics.process.rnn import wt_kindergarten, wt_nets, wt_costs, wt_reinforce_cont_new, wt_pred, parse, parse_state
%load_ext autoreload
%autoreload 2

warnings.filterwarnings("ignore", message="Trying to unpickle estimator PCA*")


In [2]:
def KEconvert2matlab(fname, dosave=True, eps=3, doadaptive=False, savedir=None):
    """
    extracts the fixed and slow points from KE minization file, clusters with DBSCAN,
    then converts to a matlab file
    :param fname: file name of kinetic energy minima located by minimization
    :param dosave: if true, saves .mat file. otherwise just returns dict
    :param eps: the DBscan parameter
    :param doadaptive: should the epsilon parameter be chosen adaptive for each network, based on PC1 range?
    :param savedir: if None, save in original data dir. this helps for trying different hyperparams
    :return: (dict) of coordinates, schur modes, eigenvalues, etc.
    """

    if not exists(fname):
        #print('data file does not exist')
        return None

    if fname.split('.')[-1] == 'dat':
        try:
        #print('loading legacy .dat data')
            with open(fname, 'rb') as f:
                dat = pickle.load(f)
        except EOFError:
            print('eof error:'+fname)
            return None

    # format everything into a matlab-friendly thing
    ns = len(dat['output'])
    #crds = np.array([[dat['output'][k][0][0], dat['output'][k][0][1]] for k in range(ns)])  # assume 2d
    crds = np.array([dat['output'][k][0] for k in range(ns)])
    KE = np.array([dat['output'][k][1] for k in range(ns)])
    jac_pc = np.array([dat['output'][k][2][0] for k in range(ns)])
    evals_full = np.array([dat['output'][k][2][1] for k in range(ns)])
    D = crds.shape[1]
    
    

    # filter the data into slow and fast points, then cluster to get the effecgtive points
    # create cutoffs for fixed points and slow points. filter data
    mask_fixed = KE < 0.0001

    # define slow points that only change magnitude < 5% across duration of trial
    lam = 2.5
    dt = 0.05
    eps_norm = 0.05
    eps_fp = 1e-3
    sdiffnorm = 2 * np.sqrt(KE) * lam / (np.linalg.norm(crds, axis=1))
    mask_slow = sdiffnorm < eps_norm

    mask = (mask_fixed) | (mask_slow)
    # potentially get ride of this
    mask = np.ones(mask.shape).astype(bool)

    crds_masked = crds[mask, :]
    KE_masked = KE[mask]
    jac_pc_masked = jac_pc[mask, :, :]
    evals_full_masked = evals_full[mask, :]
    sdiffnorm_masked = sdiffnorm[mask]
    isfixed_masked = KE_masked < 0.0001
    ns_masked = sum(mask)
    
    
    #decide on hyperparameter
    if doadaptive:
        xvals = [k[0][0] for k in dat['crds']]
        xrange = max(xvals)-min(xvals)
        eps = np.abs(xrange*eps)
        if eps==0:
            print('issue with range')
            return None
        
    
    #find number of dynamical systems features
    clustering = dbscan(eps=eps, min_samples=10).fit(crds_masked)
    labs = np.unique(clustering.labels_)
    labels = clustering.labels_
    
    #analyze the spread of each feature. relevant only for line attractors, but that analysis is later
    featurelen = []
    for k in labs:
        pts = crds_masked[labels==k]
        featurelen.append(np.max(pairwise_distances(pts)))
    
    fetdict = dict(zip(labs,featurelen))
    
    # TODO: determine if each labeled feature is in range. give it a bit of a buffer of 30%
    min_pcvals = []
    max_pcvals = []
    for m in range(D):
        min_pcvals.append(1.3*np.min(np.array([k[0][m] for k in dat['crds']])))
        max_pcvals.append(1.3*np.max(np.array([k[0][m] for k in dat['crds']])))
        
    inrange = []  # is each masked fixed/slow point within bounds of PC_min and PC_max?    
    for k in crds_masked:
        cond_all = True
        for m in range(D):
            cond_m = k[m] > min_pcvals[m] and k[m] < max_pcvals[m]  # if out of range for any PC
            cond_all = cond_all and cond_m
        inrange.append(cond_all)
            
    
    
    

    savedict = {'ns': ns_masked, 'crds': crds_masked, 'KE': KE_masked, 'jac_pc': jac_pc_masked,
                'evals_full_masked': evals_full, 'statediff': sdiffnorm_masked, 'labels': labels,
                'isfixed': isfixed_masked, 'fetdict':fetdict, 'D':D, 'inrange':inrange}
    if dosave:
        if savedir is None:
            savemat(fname.split('.')[0] + '.mat', savedict)
        else:
            savename = savedir + fname.split('/')[-1].split('.')[0]+'.mat'
            savemat(savename, savedict)
    return savedict

In [3]:


ttype = 'pkind_mem'
epoch = 'wait'
reg_idx = 1

def getdirname(ttype,epoch, num=None, reg_idx = 0, timedep=False):

    if not timedep:
        if ttype == 'full_cl_old':
            #dirname = '/Users/dhocker/projects/dynamics/results/rnn/ac/20230206_fullclstudy/full_cl_redo/dynamics/'  
            dirname = '/scratch/dh148/dynamics/results/rnn/ac/20230206_clstudy/full_cl_redo/dynamics/KEmin_constrained/' 
            if epoch == 'wait':
                txt = 'kemin_rnn_curric_*block_10reg_'+str(reg_idx)+'_D2_wait.dat'
            elif epoch == 'iti':
                txt = 'kemin_rnn_curric_*block_10reg_'+str(reg_idx)+'_D2_iti.dat'
            elif epoch == 'start':
                txt = 'kemin_rnn_curric_*block_10reg_'+str(reg_idx)+'_D2_start.dat'
            dirnames = glob.glob(dirname+txt)
            tphase = 5*np.ones((len(dirnames)))
        elif ttype == 'full_cl':
            #dirname = '/Users/dhocker/projects/dynamics/results/rnn/ac/20230206_fullclstudy/full_cl_redo/dynamics/'  
            #dirname = '/scratch/dh148/dynamics/results/rnn/ac/20230206_clstudy/full_cl_redo/dynamics/KEmin_constrained/' 
            dirname = '/scratch/dh148/dynamics/results/rnn/ac/20231003/full_cl/dynamics/KEmin_constrained/' 
            if epoch == 'wait':
                txt = 'kemin_rnn_curric_*block_10reg_'+str(reg_idx)+'_mixed_wait.dat'
            elif epoch == 'iti':
                txt = 'kemin_rnn_curric_*block_10reg_'+str(reg_idx)+'_mixed_iti.dat'
            elif epoch == 'start':
                txt = 'kemin_rnn_curric_*block_10reg_'+str(reg_idx)+'_mixed_start.dat'
            dirnames = glob.glob(dirname+txt)
            tphase = 5*np.ones((len(dirnames)))
        elif ttype == 'nok_cl':
            #dirname = '/Users/dhocker/projects/dynamics/results/rnn/ac/20230206_fullclstudy/nok_cl/dynamics/'
            dirname = '/scratch/dh148/dynamics/results/rnn/ac/20231003/nok_cl/dynamics/KEmin_constrained/' 
            if epoch == 'wait':
                txt = 'kemin_rnn_curric_*block_60reg_'+str(reg_idx)+'_mixed_wait.dat'
            elif epoch == 'iti':
                txt = 'kemin_rnn_curric_*block_60reg_'+str(reg_idx)+'_mixed_iti.dat'
            elif epoch == 'start':
                txt = 'kemin_rnn_curric_*block_60reg_'+str(reg_idx)+'_mixed_start.dat'
            dirnames = glob.glob(dirname+txt)
            tphase = 5*np.ones((len(dirnames)))
        elif ttype == 'pkind_mem':
            #dirname = '/Users/dhocker/projects/dynamics/results/rnn/ac/20230206_fullclstudy/nok_cl/dynamics/'
            dirname = '/scratch/dh148/dynamics/results/rnn/ac/20231003/pkind_mem/dynamics/KEmin_constrained/' 
            if epoch == 'wait':
                txt = 'kemin_rnn_curric_*block_20reg_'+str(reg_idx)+'_mixed_wait.dat'
            elif epoch == 'iti':
                txt = 'kemin_rnn_curric_*block_20reg_'+str(reg_idx)+'_mixed_iti.dat'
            elif epoch == 'start':
                txt = 'kemin_rnn_curric_*block_20reg_'+str(reg_idx)+'_mixed_start.dat'
            dirnames = glob.glob(dirname+txt)
            tphase = 5*np.ones((len(dirnames)))

    else:
        
        assert num is not None, "num must be supplied for timedep=True option. single RNNs only"
        
        if ttype == 'full_cl':
            idxmax_block = 10
        elif ttype == 'nok_cl':
            idxmax_block = 20
        elif ttype == 'pkind_mem':
            idxmax_block = 20
        else:
            print('issue with choosing ttype')

            
        #dirname = '/scratch/dh148/dynamics/results/rnn/ac/20230206_clstudy/full_cl_redo/dynamics/KEmin_constrained/'       
        dirname = '/scratch/dh148/dynamics/results/rnn/ac/20231003/'+ttype+'/dynamics/KEmin_constrained/'     
        dirnames = []
        tphase = []


        
        # mse kindergarten
        if ttype == 'full_cl':
            txt = 'kemin_rnn_kindergarten_'+str(num)+'_simplereg_'+str(reg_idx)+'_kind_'+epoch+'.dat'
            dirnames.append(dirname+txt)
            tphase.append(0)
            
            for k in range(1,11):
                txt = 'kemin_rnn_kindergarten_'+str(num)+'_int_0_'+str(k)+'reg_'+str(reg_idx)+'_kind_'+epoch+'.dat'
                dirnames.append(dirname+txt)
            tphase.extend(1*np.ones((10,)))
            
        if ttype == 'pkind_mem':
            
            for k in range(1,11):
                txt = 'kemin_rnn_kindergarten_'+str(num)+'_int_0_'+str(k)+'reg_'+str(reg_idx)+'_mixed_'+epoch+'.dat'
                dirnames.append(dirname+txt)
            tphase.extend(1*np.ones((10,)))
        
        
        # inference
        if ttype == 'full_cl':
            for k in range(1,101,1):
                txt = 'kemin_rnn_pred_'+str(num)+'_'+str(k)+'reg_'+str(reg_idx)+'_kind_'+epoch+'.dat'
                dirnames.append(dirname+txt)
            tphase.extend(2*np.ones((100,)))

        # task
        for k in range(1,11):
            txt = 'kemin_rnn_curric_'+str(num)+'_nocatch_'+str(k)+'reg_'+str(reg_idx)+'_mixed_'+epoch+'.dat'
            dirnames.append(dirname+txt)
        tphase.extend(3*np.ones((10,)))
        for k in range(1,11):
            txt = 'kemin_rnn_curric_'+str(num)+'_catch_'+str(k)+'reg_'+str(reg_idx)+'_mixed_'+epoch+'.dat'
            dirnames.append(dirname+txt)
        tphase.extend(4*np.ones((10,)))
        for k in range(1,idxmax_block+1):
            txt = 'kemin_rnn_curric_'+str(num)+'_block_'+str(k)+'reg_'+str(reg_idx)+'_mixed_'+epoch+'.dat'
            dirnames.append(dirname+txt)
        tphase.extend(5*np.ones((idxmax_block,)))

    
    
    return dirnames, tphase
dirnames, tphase = getdirname(ttype,epoch,timedep=True, num=10)
dirnames[-1]


'/scratch/dh148/dynamics/results/rnn/ac/20231003/pkind_mem/dynamics/KEmin_constrained/kemin_rnn_curric_10_block_20reg_0_mixed_wait.dat'

In [5]:
def classify_gen(eigs,D=2):
    """will classify eigenvalues in arbitraty dim as 
      - attractor (stable) both < 1
      - source (unstable) both > 1
      - saddle (semistable eig_1 < 1, eig2 > 1)
      - line attractor (one eig is unity, within a tolerance)
      
      (handle these later. need to check textbook)
      - stablespiral (complex eigenvalues), norm < 1
      - unstable spiral (complex)
      - orbit (purely imaginary)"""
    
    tol1_low = 0.999
    tol1_high = 1.001
    
    ltmask = eigs < tol1_low # attractors
    gtmask = eigs > tol1_high # unstable
    linemask = (eigs > tol1_low) & (eigs < tol1_high) #lines
    
    if np.sum(ltmask) == D:
        return 'attractor'
    elif np.sum(gtmask) == D:
        return 'source'
    elif np.sum(linemask) > 0:
        return 'line'
    elif np.sum(ltmask) > 0 and np.sum(gtmask) > 0:
        return 'saddle'
    else:
        print('i missed something'+ str(eigs))
        return None
    
    

In [6]:
def getnpts(dirnames,tphase, subtype='fixed', eps=3, doadaptive=False, savedir=None, restrictrange=False):
    """helper script to get number and type of fixed points
        if savedir is None: defautl to saving things in datadir. useful for different eps runs
        restrictrange is for excluding feature outside range of PC activity """
    nfixed_all = []
    nslow_all = []
    ntot = []
    
    nattractor_all = []
    nsaddle_all = []
    nline_all = []
    nsource_all = []
    ctr = 0
    cutctr = 0
    
    dirnamecpy = copy.deepcopy(dirnames)
    tphasecpy = copy.deepcopy(tphase)
    
    fetlen_all = []
    fetlentype_all = []
    inrange_all = []
    Dall = []
    
    for m in dirnames:
        #try:
        if savedir is None:
            dosave = False
        else:
            dosave = True

        dat = KEconvert2matlab(m,dosave = dosave, eps=eps, doadaptive=doadaptive, savedir=savedir)
        #except:
        #    print('issue with file:'+m)
        #    dat = None
        
        if dat is None:
            #print('didnt exist: '+m)        
            nfixed = np.nan
            nslow = np.nan
            nattractor = np.nan
            nsaddle = np.nan
            nline = np.nan
            nsource = np.nan
            npts = []
            D = np.nan
            #maximum length of feature, and its type
            fetlen = []
            fetlentype = []

        else:
            D = dat['D']
            
            if restrictrange:
                mask = dat['inrange']
            else:
                mask = np.ones(len(dat['labels'])).astype(bool)            
                
            npts = np.unique(dat['labels'][np.array(dat['labels']>-1) & np.array(mask)])

            nfixed = 0
            nslow = 0
            nattractor = 0
            nsaddle = 0
            nline = 0
            nsource = 0
            
            #maximum length of feature, and its type
            fetlen = []
            fetlentype = []
            
            #also track if feature is is in the range of network activity
            inrange = []


            for k in npts:
                pt_k_idx = np.argwhere(dat['labels']==k)[0,0]
                
                fetlen.append(dat['fetdict'][k])
                inrange.append(dat['inrange'][k])
                
                
                if dat['isfixed'][pt_k_idx]:
                    nfixed+=1
                else:
                    nslow+=1

                test = schur(dat['jac_pc'][pt_k_idx], output='real')
                evals = np.diag(test[0])
                #c = classify(evals)
                c = classify_gen(evals,D)
                #types.append(c)
                
                #add points only if following condition is met
                cond1 = dat['isfixed'][pt_k_idx] and subtype=='fixed' #update only if fixed
                cond2 = not dat['isfixed'][pt_k_idx] and subtype=='slow' #update only if slow
                cond3 = subtype == 'all' #update all of them

                #only update if fixed
                if cond1 or cond2 or cond3:
                    if c == 'attractor':
                        nattractor +=1
                        fetlentype.append('attractor')
                    elif c == 'saddle':
                        nsaddle +=1
                        fetlentype.append('saddle')
                    elif c == 'line':
                        nline +=1
                        fetlentype.append('line')
                    elif c == 'source':
                        nsource +=1     
                        fetlentype.append('source')
                    else:
                        print('none chosen')
                        

        nfixed_all.append(nfixed)
        nslow_all.append(nslow)
        ntot.append(len(npts))
        nattractor_all.append(nattractor)
        nsaddle_all.append(nsaddle)
        nline_all.append(nline)
        nsource_all.append(nsource)
        fetlen_all.append(fetlen)
        fetlentype_all.append(fetlentype)
        Dall.append(D)
        inrange_all.append(inrange)
        

        
    typelist = [nattractor_all, nsaddle_all, nline_all, nsource_all]

    
    return nfixed_all, nslow_all, ntot, typelist, dirnamecpy, tphasecpy, fetlen_all, fetlentype_all, Dall, inrange_all

    

In [16]:
#Fig. E. the end-of-training cumulative plot

st = 'all'
vers = 'full_cl'  #

#reg_idx = 0 # for OFC results
reg_idx = 1  # for STR results


savedir = None


eps = 0.01  # ratio of PC1 range
doadaptive = True
restrictrange = False

dirnames, tphase = getdirname(vers,'wait', reg_idx=reg_idx)
npts_kind_wait = getnpts(dirnames,tphase,st, eps=eps, doadaptive=doadaptive,savedir=savedir, restrictrange=restrictrange)

dirnames, tphase = getdirname(vers,'iti', reg_idx=reg_idx)
npts_kind_iti = getnpts(dirnames,tphase,st, eps=eps, doadaptive=doadaptive,savedir=savedir,restrictrange=restrictrange)

dirnames, tphase = getdirname(vers,'start', reg_idx=reg_idx)
npts_kind_start = getnpts(dirnames,tphase,st, eps=eps, doadaptive=doadaptive,savedir=savedir,restrictrange=restrictrange)


dirnames, tphase = getdirname('nok_cl','wait', reg_idx=reg_idx)
npts_classic_wait = getnpts(dirnames,tphase,st, eps=eps, doadaptive=doadaptive,savedir=savedir,restrictrange=restrictrange)

dirnames, tphase = getdirname('nok_cl','iti', reg_idx=reg_idx)
npts_classic_iti = getnpts(dirnames,tphase,st, eps=eps, doadaptive=doadaptive,savedir=savedir,restrictrange=restrictrange)

dirnames, tphase = getdirname('nok_cl','start', reg_idx=reg_idx)
npts_classic_start = getnpts(dirnames,tphase,st, eps=eps, doadaptive=doadaptive,savedir=savedir,restrictrange=restrictrange)


dirnames, tphase = getdirname('pkind_mem','wait', reg_idx=reg_idx)
npts_mem_wait = getnpts(dirnames,tphase,st, eps=eps, doadaptive=doadaptive,savedir=savedir,restrictrange=restrictrange)

dirnames, tphase = getdirname('pkind_mem','iti', reg_idx=reg_idx)
npts_mem_iti = getnpts(dirnames,tphase,st, eps=eps, doadaptive=doadaptive,savedir=savedir,restrictrange=restrictrange)

dirnames, tphase = getdirname('pkind_mem','start', reg_idx=reg_idx)
npts_mem_start = getnpts(dirnames,tphase,st, eps=eps, doadaptive=doadaptive,savedir=savedir,restrictrange=restrictrange)




In [17]:
# want to save intermediate files?
savedict = {'npts_kind_wait':npts_kind_wait, 'npts_kind_iti':npts_kind_iti, 'npts_kind_start':npts_kind_start,
            'npts_classic_wait':npts_classic_wait, 'npts_classic_iti':npts_classic_iti, 'npts_classic_start':npts_classic_start,
            'npts_mem_wait':npts_mem_wait, 'npts_mem_iti':npts_mem_iti, 'npts_mem_start':npts_mem_start}

dbase = '/scratch/dh148/dynamics/results/rnn/ac/20231003/figs/'
with open(dbase+'ke_results_str_wait_all3.dat','wb') as f:
    test = pickle.dump(savedict,f)

In [52]:
# try the pop average
npts_all = []
npts_all_str = []

dims_all = []
dims_all_str = []

#userange = range(1,21)
#userange = [2,3,4,5,6,7,8,9,10,11,13,14,15,16,18,19]  #until all networks are done
userange = range(1,51)

eps = 0.01
doadaptive = True
timedep = True


for j in userange:
    print(j)
    st = 'all'
    ttype = 'full_cl'
    epoch = 'wait'
    num = j
    reg_idx = 0
    dirnames, tphase = getdirname(ttype,epoch,num,reg_idx,timedep=timedep)
    npts_kind_wait = getnpts(dirnames,tphase,st, eps=eps, doadaptive=doadaptive)
    npts_j = np.array(npts_kind_wait[0])+np.array(npts_kind_wait[1])
    npts_all.append(npts_j)
    dims_all.append(npts_kind_wait[-2])
    
for j in userange:
    print(j)
    st = 'all'
    ttype = 'full_cl'
    epoch = 'wait'
    num = j
    reg_idx = 1
    dirnames, tphase_str = getdirname(ttype,epoch,num,reg_idx,timedep=timedep)
    npts_kind_wait = getnpts(dirnames,tphase,st, eps=eps, doadaptive=doadaptive)
    npts_j = np.array(npts_kind_wait[0])+np.array(npts_kind_wait[1])
    npts_all_str.append(npts_j)
    dims_all_str.append(npts_kind_wait[-2])
    


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30


  sdiffnorm = 2 * np.sqrt(KE) * lam / (np.linalg.norm(crds, axis=1))


31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
1
2
issue with range
issue with range
issue with range
issue with range
issue with range
issue with range
issue with range
issue with range
issue with range
3
issue with range
issue with range
4
5
6
7
issue with range
issue with range
issue with range
issue with range
issue with range
issue with range
8
9
issue with range
issue with range
issue with range
issue with range
issue with range
issue with range
10
11
issue with range
12
13
issue with range
issue with range
14
15
issue with range
16
17
issue with range
issue with range
18
19
issue with range
issue with range
issue with range
20
21
issue with range
issue with range
issue with range
22
23
issue with range
24
issue with range
issue with range
issue with range
25
26
27
issue with range
28
issue with range
issue with range
29
30
issue with range
31
32
33
34
35
36
issue with range
37
38
39
40
issue with range
issue with range
issue with range
41
42
43
issue with range
issu

In [53]:
dosave = True
doload = False

if dosave:

    savedir = '/scratch/dh148/dynamics/results/rnn/ac/20231003/figs/'
    savename = savedir + 'ke_overtraining.dat'

    with open(savename,'wb') as f:
        pickle.dump([npts_all,npts_all_str,dims_all, dims_all_str,tphase, tphase_str],f)
        
if doload:
    savedir = '/scratch/dh148/dynamics/results/rnn/ac/20231003/figs/'
    savename = savedir + 'ke_overtraining.dat'

    with open(savename,'rb') as f:
        [npts_all,npts_all_str,dims_all, dims_all_str,tphase, tphase_str] = pickle.load(f)

In [19]:
# do it for shaping only and memory

# try the pop average
npts_all_shp = []
npts_all_shp_str = []

dims_all_shp = []
dims_all_shp_str = []

#userange = range(1,21)
#userange = [2,3,4,5,6,7,8,9,10,11,13,14,15,16,18,19]  #until all networks are done
userange = range(1,21)

eps = 0.01
doadaptive = True
timedep = True


for j in userange:
    print(j)
    st = 'all'
    ttype = 'nok_cl'
    epoch = 'wait'
    num = j
    reg_idx = 0
    dirnames_shp, tphase_shp = getdirname(ttype,epoch,num,reg_idx,timedep=timedep)
    npts_kind_wait = getnpts(dirnames_shp,tphase_shp,st, eps=eps, doadaptive=doadaptive)
    npts_j = np.array(npts_kind_wait[0])+np.array(npts_kind_wait[1])
    npts_all_shp.append(npts_j)
    dims_all_shp.append(npts_kind_wait[-2])

for j in userange:
    print(j)
    st = 'all'
    ttype = 'nok_cl'
    epoch = 'wait'
    num = j
    reg_idx = 1
    dirnames_shp, tphase_shp_str = getdirname(ttype,epoch,num,reg_idx,timedep=timedep)
    npts_kind_wait = getnpts(dirnames_shp,tphase_shp_str,st, eps=eps, doadaptive=doadaptive)
    npts_j = np.array(npts_kind_wait[0])+np.array(npts_kind_wait[1])
    npts_all_shp_str.append(npts_j)
    dims_all_shp_str.append(npts_kind_wait[-2])


1
2
eof error:/scratch/dh148/dynamics/results/rnn/ac/20231003/nok_cl/dynamics/KEmin_constrained/kemin_rnn_curric_2_block_14reg_0_mixed_wait.dat
3
4
eof error:/scratch/dh148/dynamics/results/rnn/ac/20231003/nok_cl/dynamics/KEmin_constrained/kemin_rnn_curric_4_block_6reg_0_mixed_wait.dat
eof error:/scratch/dh148/dynamics/results/rnn/ac/20231003/nok_cl/dynamics/KEmin_constrained/kemin_rnn_curric_4_block_15reg_0_mixed_wait.dat
5
6
7
8
9
eof error:/scratch/dh148/dynamics/results/rnn/ac/20231003/nok_cl/dynamics/KEmin_constrained/kemin_rnn_curric_9_block_4reg_0_mixed_wait.dat
eof error:/scratch/dh148/dynamics/results/rnn/ac/20231003/nok_cl/dynamics/KEmin_constrained/kemin_rnn_curric_9_block_14reg_0_mixed_wait.dat
10
eof error:/scratch/dh148/dynamics/results/rnn/ac/20231003/nok_cl/dynamics/KEmin_constrained/kemin_rnn_curric_10_block_8reg_0_mixed_wait.dat
11
12
13
14
15
16
eof error:/scratch/dh148/dynamics/results/rnn/ac/20231003/nok_cl/dynamics/KEmin_constrained/kemin_rnn_curric_16_block_16reg

In [10]:
dosave = False
doload = True

if dosave:

    savedir = '/scratch/dh148/dynamics/results/rnn/ac/20231003/figs/'
    savename = savedir + 'ke_overtraining_shp.dat'

    with open(savename,'wb') as f:
        pickle.dump([npts_all_shp,npts_all_shp_str,dims_all_shp, dims_all_shp_str,tphase_shp, tphase_shp_str],f)
        
if doload:
    savedir = '/scratch/dh148/dynamics/results/rnn/ac/20231003/figs/'
    savename = savedir + 'ke_overtraining_shp.dat'

    with open(savename,'rb') as f:
        [npts_all_shp,npts_all_shp_str,dims_all_shp, dims_all_shp_str,tphase_shp, tphase_shp_str] = pickle.load(f)

In [21]:
# memory
# do it for shaping only and memory

# try the pop average
npts_all_mem = []
npts_all_mem_str = []

dims_all_mem = []
dims_all_mem_str = []

userange = range(1,21)

eps = 0.01
doadaptive = True
timedep = True


for j in userange:
    print(j)
    st = 'all'
    ttype = 'pkind_mem'
    epoch = 'wait'
    num = j
    reg_idx = 0
    dirnames_mem, tphase_mem = getdirname(ttype,epoch,num,reg_idx,timedep=timedep)
    npts_kind_wait = getnpts(dirnames_mem,tphase_mem,st, eps=eps, doadaptive=doadaptive)
    npts_j = np.array(npts_kind_wait[0])+np.array(npts_kind_wait[1])
    npts_all_mem.append(npts_j)
    dims_all_mem.append(npts_kind_wait[-2])

for j in userange:
    print(j)
    st = 'all'
    ttype = 'pkind_mem'
    epoch = 'wait'
    num = j
    reg_idx = 1
    dirnames_mem, tphase_mem_str = getdirname(ttype,epoch,num,reg_idx,timedep=timedep)
    #print(dirnames_mem)
    npts_kind_wait = getnpts(dirnames_mem,tphase_mem_str,st, eps=eps, doadaptive=doadaptive)
    npts_j = np.array(npts_kind_wait[0])+np.array(npts_kind_wait[1])
    npts_all_mem_str.append(npts_j)
    dims_all_mem_str.append(npts_kind_wait[-2])


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
1
2
3


  sdiffnorm = 2 * np.sqrt(KE) * lam / (np.linalg.norm(crds, axis=1))


4
5
6
7
8
9


  sdiffnorm = 2 * np.sqrt(KE) * lam / (np.linalg.norm(crds, axis=1))
  sdiffnorm = 2 * np.sqrt(KE) * lam / (np.linalg.norm(crds, axis=1))
  sdiffnorm = 2 * np.sqrt(KE) * lam / (np.linalg.norm(crds, axis=1))


10
11
12
13
14
15
16
17
18
19


  sdiffnorm = 2 * np.sqrt(KE) * lam / (np.linalg.norm(crds, axis=1))


20


  sdiffnorm = 2 * np.sqrt(KE) * lam / (np.linalg.norm(crds, axis=1))


In [11]:
dosave = False
doload = True

if dosave:

    savedir = '/scratch/dh148/dynamics/results/rnn/ac/20231003/figs/'
    savename = savedir + 'ke_overtraining_mem.dat'

    with open(savename,'wb') as f:
        pickle.dump([npts_all_mem,npts_all_mem_str,dims_all_mem, dims_all_mem_str,tphase_mem, tphase_mem_str],f)
        
if doload:
    savedir = '/scratch/dh148/dynamics/results/rnn/ac/20231003/figs/'
    savename = savedir + 'ke_overtraining_mem.dat'

    with open(savename,'rb') as f:
        [npts_all_mem,npts_all_mem_str,dims_all_mem, dims_all_mem_str, tphase_mem, tphase_mem_str] = pickle.load(f)