In [1]:
# Basic Libraries
import sys
import time
import gc
import random
import copy 
from mpl_toolkits.mplot3d import Axes3D
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# PyTorch Libraries
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Data Handling and Image Processing
from torchvision import datasets, transforms

# Visualization Libraries
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from skimage.metrics import structural_similarity as ssim

# Style for Matplotlib
import scienceplots
plt.style.use('science')
plt.style.use(['no-latex'])

# Scientific Computing and Machine Learning
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from scipy.linalg import subspace_angles
from scipy.spatial.distance import cosine
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import pearsonr

# Custom Modules and Extensions
sys.path.append("../netrep/")
sys.path.append("../svcca/")

import networks as nets  # Contains RNNs
import net_helpers
import mpn_tasks
import helper
import mpn

import scienceplots
plt.style.use('science')
plt.style.use(['no-latex'])

# Memory Optimization
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

In [2]:
# Shared plotting palettes and line/marker styles used by the figures below.
# The three colour lists are index-aligned (base, light, dark shade of the same hue)
# and each repeats its nine colours 10 times so that indices above 8 wrap around.
# 0 Red, 1 blue, 2 green, 3 purple, 4 orange, 5 teal, 6 gray, 7 pink, 8 yellow
c_vals = ['#e53e3e', '#3182ce', '#38a169', '#805ad5','#dd6b20', '#319795', '#718096', '#d53f8c', '#d69e2e',] * 10
c_vals_l = ['#feb2b2', '#90cdf4', '#9ae6b4', '#d6bcfa', '#fbd38d', '#81e6d9', '#e2e8f0', '#fbb6ce', '#faf089',] * 10
c_vals_d = ['#9b2c2c', '#2c5282', '#276749', '#553c9a', '#9c4221', '#285e61', '#2d3748', '#97266d', '#975a16',] * 10 
# Matplotlib line-style specs (named styles, shorthand strings and dash tuples).
l_vals = ['solid', 'dashed', 'dotted', 'dashdot', '-', '--', '-.', ':', (0, (3, 1, 1, 1)), (0, (5, 10))]
# Matplotlib marker codes; '*' and '+' each appear twice in this list.
markers_vals = ['o', 'v', '*', '+', '>', '1', '2', '3', '4', 's', 'p', '*', 'h', 'H', '+', 'x', 'D', 'd', '|', '_']
linestyles = ["-", "--", "-."]

In [3]:
# Notebook-wide run configuration; later cells fill in and read its keys
# (e.g. 'task_type', 'ruleset', 'chosen_network', 'addon_name').
hyp_dict = {}

In [4]:
# Reload modules if changes have been made to them
from importlib import reload

# Only `nets` and `net_helpers` are reloaded here; the other local modules
# (mpn, mpn_tasks, helper) keep whatever version was imported first.
reload(nets)
reload(net_helpers)

fixseed = False # randomize setting the seed may lead to not perfectly solved results
# Draw a fresh seed in [1, 1000] by default to test robustness; use the fixed value 8 when fixseed is set.
seed = random.randint(1,1000) if not fixseed else 8
print(f"Set seed {seed}")
# Seed every RNG the notebook draws from. Python's `random` has to be seeded as well:
# the similarity analysis further down resamples with random.choice, so without this
# line the printed seed would not reproduce those results.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Task family, batching mode and ruleset for this run.
hyp_dict['task_type'] = 'multitask' # int, NeuroGym, multitask
hyp_dict['mode_for_all'] = "random_batch"
hyp_dict['ruleset'] = 'delaygofamily' # low_dim, all, test

# Rules that generate_response_stimulus accepts; any other rule makes it raise NotImplementedError.
accept_rules = ('fdgo', 'fdanti', 'delaygo', 'delayanti', 'reactgo', 'reactanti', 
                'delaydm1', 'delaydm2', 'dmsgo', 'dmcgo', 'contextdelaydm1', 'contextdelaydm2', 'multidelaydm')


# Maps a ruleset name (the value of hyp_dict['ruleset']) to the list of rule names it contains.
rules_dict = \
    {'all' : ['fdgo', 'reactgo', 'delaygo', 'fdanti', 'reactanti', 'delayanti',
              'dm1', 'dm2', 'contextdm1', 'contextdm2', 'multidm',
              'delaydm1', 'delaydm2', 'contextdelaydm1', 'contextdelaydm2', 'multidelaydm',
              'dmsgo', 'dmsnogo', 'dmcgo', 'dmcnogo'],
     'low_dim' : ['fdgo', 'reactgo', 'delaygo', 'fdanti', 'reactanti', 'delayanti',
                 'delaydm1', 'delaydm2', 'contextdelaydm1', 'contextdelaydm2', 'multidelaydm',
                 'dmsgo', 'dmsnogo', 'dmcgo', 'dmcnogo'],
     'gofamily': ['fdgo', 'fdanti', 'reactgo', 'reactanti', 'delaygo', 'delayanti'],
     'delaygo': ['delaygo'],
     'delaygofamily': ['delaygo', 'delayanti'],
     'fdgo': ['fdgo'],
     'fdfamily': ['fdgo', 'fdanti'],
     'reactgo': ['reactgo'],
     'reactfamily': ['reactgo', 'reactanti'],
     'delaydm1': ['delaydm1'],
     'delaydmfamily': ['delaydm1', 'delaydm2'],
     'dmsgofamily': ['dmsgo', 'dmsnogo'],
     'dmsgo': ['dmsgo'],
     'dmcgo': ['dmcgo'],
     'contextdelayfamily': ['contextdelaydm1', 'contextdelaydm2'],
    }
    

# This can either be used to set parameters OR set parameters and train
train = True # whether or not to train the network
verbose = True
hyp_dict['run_mode'] = 'minimal' # minimal, debug
hyp_dict['chosen_network'] = "dmpn"

# suffix for saving images
# inputadd, Wfix, WL2, hL2
# inputrandom, Wtrain
# noise001
# largeregularization
# trainetalambda

# Number of hidden layers and units per hidden layer; both feed 'n_neurons' in current_basic_params.
mpn_depth = 1
n_hidden = 200

# 'addon_name' is embedded in the file names of the saved figures.
hyp_dict['addon_name'] = "inputrandom+Wtrain+WL2+hL2"
hyp_dict['addon_name'] += f"+hidden{n_hidden}"

# for coding 
# GRU / vanilla networks are always built with a single hidden layer.
if hyp_dict['chosen_network'] in ("gru", "vanilla"):
    mpn_depth = 1

def current_basic_params():
    """Build the task, training and network parameter dictionaries for this run.

    Most values are literals; the run-specific pieces are read from the
    notebook-level ``hyp_dict``, ``rules_dict``, ``train``, ``n_hidden`` and
    ``mpn_depth``.

    Returns:
        tuple: ``(task_params, train_params, net_params)`` dictionaries.

    Raises:
        AssertionError: if ``train`` is falsy while ``n_epochs_per_set`` is not 0,
            if input randomisation and the added input layer are both enabled,
            or if an expected ``ml_params*`` entry is missing from the network
            parameters.
    """
    task_cfg = {
        'task_type': hyp_dict['task_type'],
        'rules': rules_dict[hyp_dict['ruleset']],
        'dt': 40,  # ms; directly influences sequence lengths
        'ruleset': hyp_dict['ruleset'],
        'n_eachring': 8,  # number of distinct possible inputs on each ring
        'in_out_mode': 'low_dim',  # high_dim, low_dim or low_dim_pos (Robert vs. Laura's paper, resp)
        'sigma_x': 0.00,  # Laura raised to 0.1 to prevent overfitting (Robert uses 0.01)
        'mask_type': 'cost',  # 'cost' or None
        'fixate_off': True,  # second fixation signal goes on when the first is off
        'task_info': True,
        'randomize_inputs': False,
        'n_input': 20,  # only used if inputs are randomized
        'modality_diff': False,
        'label_strength': False,
        'long_delay': 'normal',
        'long_response': 'normal',
        'adjust_task_prop': True,
        'adjust_task_decay': 0.9,
    }

    print(f"Fixation_off: {task_cfg['fixate_off']}; Task_info: {task_cfg['task_info']}")

    train_cfg = {
        'lr': 1e-3,
        'n_batches': 128,
        'batch_size': 128,
        'gradient_clip': 10,
        'valid_n_batch': 100,
        'n_datasets': 10000,  # number of distinct batches
        'valid_check': None,
        'n_epochs_per_set': 1,  # longer/shorter training
        'weight_reg': 'L2',
        'activity_reg': 'L2',
        'reg_lambda': 1e-4,
    }

    # A parameters-only run is expected to request zero epochs per dataset.
    if not train:
        assert train_cfg['n_epochs_per_set'] == 0

    # Settings of the (single) modulation layer; also present for GRU / vanilla runs.
    layer_cfg = {
        'bias': True,  # bias of layer
        'mp_type': 'mult',
        'm_update_type': 'hebb_assoc',  # hebb_assoc, hebb_pre
        'eta_type': 'scalar',  # scalar, pre_vector, post_vector, matrix
        'eta_train': False,
        # 'eta_init': 'mirror_gaussian', #0.0,
        'lam_type': 'scalar',  # scalar, pre_vector, post_vector, matrix
        'm_time_scale': 4000,  # ms, sets lambda
        'lam_train': False,
        'W_freeze': False,  # different combination with [input_layer_add_trainable]
    }

    net_cfg = {
        'net_type': hyp_dict['chosen_network'],  # mpn1, dmpn, vanilla
        'n_neurons': [1] + [n_hidden] * mpn_depth + [1],
        'output_bias': False,  # turn off biases for easier interpretation
        'loss_type': 'MSE',  # XE, MSE
        'activation': 'tanh',  # linear, ReLU, sigmoid, tanh, tanh_re, tukey, heaviside
        'cuda': True,
        'monitor_freq': train_cfg["n_epochs_per_set"],
        'monitor_valid_out': True,  # whether or not to save validation output throughout training
        'output_matrix': '',  # "" (default), "untrained" or "orthogonal"
        'input_layer_add': True,
        'input_layer_add_trainable': False,  # revise this is effectively to [randomize_inputs], tune this
        'input_layer_bias': False,
        'input_layer': "trainable",  # for RNN only
        'acc_measure': 'stimulus',

        # for one-layer MPN, GRU or Vanilla
        'ml_params': layer_cfg,

        # Vanilla RNN params
        'leaky': True,
        'alpha': 0.2,
    }

    # The two input-mixing options must never be enabled together.
    both_input_options = task_cfg["randomize_inputs"] and net_cfg["input_layer_add"]
    assert not both_input_options, (
        "task_params['randomize_inputs'] and net_params['input_layer_add'] cannot both be True."
    )

    # Deeper MPNs need one numbered 'ml_params<idx>' entry per additional layer
    # (the range is empty for a single-layer network).
    for mpl_idx in range(mpn_depth - 1):
        assert f'ml_params{mpl_idx}' in net_cfg

    # Probably redundant, but kept as a reminder to re-check the parameters
    # whenever the network type is switched.
    if hyp_dict['chosen_network'] in ("gru", "vanilla"):
        assert 'ml_params' in net_cfg

    return task_cfg, train_cfg, net_cfg

# Assemble the raw parameter dictionaries for this run.
task_params, train_params, net_params = current_basic_params()

# 0 when 'fixate_off' is enabled, 1 otherwise; find_task uses it to offset the task-label slice.
shift_index = 0 if task_params['fixate_off'] else 1

# Only the multitask pipeline is implemented.
if hyp_dict['task_type'] not in ('multitask',):
    raise NotImplementedError()

task_params, train_params, net_params = mpn_tasks.convert_and_init_multitask_params(
    (task_params, train_params, net_params)
)
net_params['prefs'] = mpn_tasks.get_prefs(task_params['hp'])

print('Rules: {}'.format(task_params['rules']))
print('  Input size {}, Output size {}'.format(task_params['n_input'], task_params['n_output']))

# Choose the torch device from the 'cuda' flag as it stands after the conversion above.
if not net_params['cuda']:
    print('Using CPU...')
    device = torch.device('cpu')
else:
    print('Using CUDA...')
    device = torch.device('cuda')

# how many epochs each dataset will be trained on
epoch_multiply = train_params["n_epochs_per_set"]

Set seed 114
Fixation_off: True; Task_info: True
Rules: ['delaygo', 'delayanti']
  Input size 8, Output size 3
Using CPU...


In [5]:
hyp_dict["mess_with_training"] = False

# Tag the saved file names when the training procedure is being tampered with.
if hyp_dict['mess_with_training']:
    hyp_dict['addon_name'] += "messwithtraining"

# Bundle handed to net_helpers.train_network.
params = task_params, train_params, net_params

# Select the network class to instantiate from the configured network type.
if net_params['net_type'] == 'mpn1':
    netFunction = mpn.MultiPlasticNet
elif net_params['net_type'] == 'dmpn':
    netFunction = mpn.DeepMultiPlasticNet
elif net_params['net_type'] == 'vanilla':
    netFunction = nets.VanillaRNN
elif net_params['net_type'] == 'gru':
    netFunction = nets.GRU
else:
    # Fail here rather than with a NameError (or a stale netFunction left over
    # from an earlier run) when the training cell is executed.
    raise ValueError(
        f"Unknown net_type {net_params['net_type']!r}; expected 'mpn1', 'dmpn', 'vanilla' or 'gru'."
    )

In [6]:
# Number of test trials requested per rule set; reuses the validation batch count.
test_n_batch = train_params["valid_n_batch"]
# Which per-trial label is used as `labels` below: stimulus or response.
color_by = "stim" # or "resp" 

# Passed as `fix=` to generate_trials_wrap for both test sets.
task_random_fix = True
if task_random_fix:
    print(f"Align {task_params['rules']} With Same Time")

if task_params['task_type'] in ('multitask',): # Test batch consists of all the rules
    # NOTE: this mutates the shared task_params['hp'] in place.
    task_params['hp']['batch_size_train'] = test_n_batch
    # using homogeneous cutting off
    test_mode_for_all = "random"
    # ZIHAN
    # generate test data using "random"
    test_data, test_trials_extra = mpn_tasks.generate_trials_wrap(task_params, test_n_batch, \
                rules=task_params['rules'], mode_input=test_mode_for_all, fix=task_random_fix
    )
    _, test_trials, test_rule_idxs = test_trials_extra

    # Second test set: identical settings except 'long_delay' is switched to "long".
    # The deep copy keeps the original task_params untouched.
    task_params_attractor = copy.deepcopy(task_params)
    task_params_attractor["long_delay"] = "long"
    test_data_attractor, test_trials_extra_attractor = mpn_tasks.generate_trials_wrap(task_params_attractor, test_n_batch, \
                                                                                      rules=task_params_attractor['rules'], \
                                                                                      mode_input=test_mode_for_all, fix=task_random_fix)
    
    _, test_trials_attractor, test_rule_idxs_attractor = test_trials_extra_attractor

    task_params['dataset_name'] = 'multitask'

    # Human-readable names of the output channels for the chosen input/output mode.
    if task_params['in_out_mode'] in ('low_dim_pos',):
        output_dim_labels = ('Fixate', 'Cos', '-Cos', 'Sin', '-Sin')
    elif task_params['in_out_mode'] in ('low_dim',):
        output_dim_labels = ('Fixate', 'Cos', 'Sin')
    else:
        raise NotImplementedError()

    def generate_response_stimulus(task_params, test_trials): 
        """Collect the per-trial response and stimulus labels of every rule.

        Args:
            task_params: dict whose 'rules' entry lists the rule names, in the
                same order as `test_trials`.
            test_trials: sequence indexed by rule position; each element exposes
                a `meta` mapping with 'resp1' and 'stim1' (or 'matches' for the
                match rulesets).

        Returns:
            tuple: `(labels_resp, labels_stim)`, each a column array of shape
            (n_trials_total, 1) with the rules concatenated in order.

        Raises:
            NotImplementedError: if a rule is not listed in `accept_rules`.
        """
        labels_resp, labels_stim = [], []
        for rule_idx, rule in enumerate(task_params['rules']):
            print(rule)
            if rule not in accept_rules:
                raise NotImplementedError()

            trial_meta = test_trials[rule_idx].meta
            if hyp_dict['ruleset'] in ('dmsgo', 'dmcgo'):
                # These rulesets are labelled by 'matches' only. The original code
                # appended to an undefined `labels` list here (NameError) and left
                # both outputs empty; the match label is now used for both outputs
                # so that either `color_by` choice works.
                # NOTE(review): confirm that 'matches' is the intended label for both.
                labels_resp.append(trial_meta['matches'])
                labels_stim.append(trial_meta['matches'])
            else:
                labels_resp.append(trial_meta['resp1'])
                labels_stim.append(trial_meta['stim1'])

        # Stack the rules and return one label per row.
        labels_resp = np.concatenate(labels_resp, axis=0).reshape(-1,1)
        labels_stim = np.concatenate(labels_stim, axis=0).reshape(-1,1)

        return labels_resp, labels_stim

    # Labels of the normal-delay test set (still inside the multitask branch).
    labels_resp, labels_stim = generate_response_stimulus(task_params, test_trials)

# Pick the label type that downstream cells use as `labels`.
labels = labels_stim if color_by == "stim" else labels_resp
    
test_input, test_output, test_mask = test_data
test_input_attractor, test_output_attractor, test_mask_attractor = test_data_attractor
print(test_input_attractor.shape)
print(test_output_attractor.shape)

# Shuffle the normal-delay test set along its first axis, applying the same
# permutation to inputs, targets, mask and `labels`.
# NOTE: labels_resp / labels_stim, test_trials and the long-delay (attractor)
# set are NOT permuted, so they stay in generation order.
permutation = np.random.permutation(test_input.shape[0])
test_input = test_input[permutation]
test_output = test_output[permutation]
test_mask = test_mask[permutation]
labels = labels[permutation]

test_input_np = test_input.detach().cpu().numpy()
test_output_np = test_output.detach().cpu().numpy()

# Total number of batches, might be different than test_n_batch
# this should be the same regardless of variety of test_input
n_batch_all = test_input_np.shape[0] 

def find_task(task_params, test_input_np, shift_index):
    """Identify, for every trial, which task-label channel is switched on.

    The label is read from the first time step of each trial, from input
    channel ``6 - shift_index`` onwards, and the returned index is the position
    (within that slice) of the entry closest to 1.

    Args:
        task_params: dict with 'randomize_inputs'; when that is truthy it must
            also hold 'randomize_matrix', whose pseudo-inverse is used to map
            the inputs back before reading the label.
        test_input_np: array indexed as (trial, time, channel).
        shift_index: offset subtracted from the first task-label channel (6).

    Returns:
        list: one integer index per trial (first match on ties).

    Raises:
        ValueError: if no closest entry can be determined for a trial
            (e.g. the label slice contains NaN or is empty).
    """
    # Undo the random input projection once for the whole array; the original
    # recomputed the pseudo-inverse and the full product for every trial.
    if task_params["randomize_inputs"]:
        inputs = test_input_np @ np.linalg.pinv(task_params["randomize_matrix"])
    else:
        inputs = test_input_np

    test_task = []  # which task
    for batch_idx in range(inputs.shape[0]):
        task_label = np.asarray(inputs[batch_idx, 0, 6 - shift_index:])
        # Nearest-to-1 rather than an exact equality test, so noisy or
        # re-projected inputs are still decoded.
        dist = np.abs(task_label - 1)
        indices = np.where(dist == dist.min())[0]
        if not indices.size:
            raise ValueError("No entry close enough to 1 found")
        test_task.append(indices[0])

    return test_task

# Task index per trial for the (permuted) normal-delay set and for the long-delay set.
test_task = find_task(task_params, test_input_np, shift_index)
test_task_attractor = find_task(task_params_attractor, test_input_attractor.detach().cpu().numpy(), shift_index)

Align ['delaygo', 'delayanti'] With Same Time
rng reset with seed 9643
rng reset with seed 9643
rng reset with seed 9643
rng reset with seed 9643
delaygo
delayanti
torch.Size([200, 802, 8])
torch.Size([200, 802, 3])


In [7]:
# we use net at different training stage on the same test_input
start_time = time.time()
net, _, (counter_lst, netout_lst, db_lst, Winput_lst, Winputbias_lst,\
         Woutput_lst, Wall_lst, marker_lst, loss_lst, acc_lst), _ = net_helpers.train_network(params, device=device, verbose=verbose,
                                                                                              train=train, hyp_dict=hyp_dict,\
                                                                                              netFunction=netFunction,\
                                                                                              test_input=[test_input, test_input_attractor],
                                                                                              print_frequency=100)

end_time = time.time()
print(f"Running Time: {end_time - start_time}")
counter_lst = [x * epoch_multiply + 1 for x in counter_lst] # avoid log plot issue    

MultiPlastic Net:
  output neurons: 3
  Act: tanh

  Input Layer Frozen.
  MP Layer1 parameters:
    n_neurons - input: 200, output: 200
    M matrix parameters:    update bounds - Max mult: 1.0, Min mult: -1.0
      type: mult // Update - type: hebb_assoc // Act fn: linear
      Eta: scalar (fixed) // Lambda: scalar (fixed) // Lambda_max: 0.99 (tau: 4.0e+03)
  No Hidden Recurrency.
Trainable parameters: 40,800
W_output: (3, 200)
mp_layer1.W: (200, 200)
mp_layer1.b: (200,)
task_params['rules_probs']: [0.5 0.5]
Rule: delaygo
Rule delaygo seq_len 103, max_seq_len 103
inputs_all paddled: (128, 103, 8)
inputs_all: torch.Size([128, 103, 8])


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)

In [None]:
# Norm of the recorded input-layer weights at each checkpoint; only drawn for
# the dmpn network when the extra input layer is enabled.
if hyp_dict['chosen_network'] == "dmpn" and net_params["input_layer_add"]:
    fignorm, axsnorm = plt.subplots(1,1,figsize=(4,4))
    winput_norms = [np.linalg.norm(Winput_matrix) for Winput_matrix in Winput_lst]
    axsnorm.plot(counter_lst, winput_norms, "-o")
    axsnorm.set_xscale("log")
    axsnorm.set_ylabel("Frobenius Norm")

In [None]:
# sanity check, if W_freeze, then the recorded W matrix for the modulation layer should not be changed
if net_params["ml_params"]["W_freeze"]: 
    assert np.allclose(Wall_lst[-1][0], Wall_lst[0][0])

if net_params["input_layer_bias"]: 
    assert net_params["input_layer_add"] is True 

In [None]:
# Training / validation accuracy over the monitored iterations, saved to ./twotasks/.
if train:
    fig, ax = plt.subplots(1,1,figsize=(3,3))
    # The first monitored point is skipped for the accuracy curves only.
    ax.plot(net.hist['iters_monitor'][1:], net.hist['train_acc'][1:], color=c_vals[0], label='Full train accuracy')
    ax.plot(net.hist['iters_monitor'][1:], net.hist['valid_acc'][1:], color=c_vals[1], label='Full valid accuracy')
    if net.weight_reg is not None:
        # Loss components are drawn behind the accuracy curves (zorder=-1) on the same axis.
        # NOTE(review): they share the 'Accuracy' y axis and are clipped by the
        # set_ylim([0.0, 1.05]) call below — confirm this is intended.
        ax.plot(net.hist['iters_monitor'], net.hist['train_loss_output_label'], color=c_vals_l[0], zorder=-1, label='Output label')
        ax.plot(net.hist['iters_monitor'], net.hist['train_loss_reg_term'], color=c_vals_l[0], zorder=-1, label='Reg term', linestyle='dashed')
        ax.plot(net.hist['iters_monitor'], net.hist['valid_loss_output_label'], color=c_vals_l[1], zorder=-1, label='Output valid label')
        ax.plot(net.hist['iters_monitor'], net.hist['valid_loss_reg_term'], color=c_vals_l[1], zorder=-1, label='Reg valid term', linestyle='dashed')
    
    # ax.set_yscale('log')
    ax.legend()
    ax.set_ylim([0.0, 1.05])
    # ax.set_ylabel('Loss ({})'.format(net.loss_type))
    ax.set_ylabel('Accuracy')
    ax.set_xlabel('# Batches')
    # The ./twotasks/ directory must already exist; savefig does not create it.
    plt.savefig(f"./twotasks/loss_{hyp_dict['ruleset']}_seed{seed}_{hyp_dict['addon_name']}.png", dpi=300)
    
print('Done!')

In [None]:
# Run the project's eta/lambda analysis on the trained network (skipped when nothing was trained).
if train:
    net_helpers.net_eta_lambda_analysis(net, net_params, hyp_dict)

In [None]:
use_finalstage = False

# Index of the last recorded checkpoint. It is defined for both branches because
# later cells index the recorded outputs with `ind` (e.g. netout_lst[1][ind]);
# previously it only existed in the `else` branch, so enabling use_finalstage
# caused a NameError further down.
ind = len(marker_lst)-1 
# ind = 0

if use_finalstage:
    # Re-run the final network on the normal-delay test set.
    # plotting output in the validation set
    net_out, db = net.iterate_sequence_batch(test_input, run_mode='track_states')
    W_output = net.W_output.detach().cpu().numpy()

    W_all_ = [net.mp_layers[i].W.detach().cpu().numpy() for i in range(len(net.mp_layers))]
    W_ = W_all_[0]
    
else:
    # Use the outputs/weights recorded during training at checkpoint `ind`.
    network_at_percent = (marker_lst[ind]+1)/train_params['n_datasets']*100
    print(f"Using network at {network_at_percent}%")
    # by default using the first test_input 
    net_out = netout_lst[0][ind]
    db = db_lst[0][ind]
    W_output = Woutput_lst[ind]
    W_ = Wall_lst[ind][0]

In [None]:
def plot_input_output(test_input_np, net_out, test_output_np, test_task=None, tag="", batch_num=5):
    """Plot network outputs against targets (left) and inputs (right) for the first trials.

    One row is drawn per trial. With a single output channel only the left
    column is filled; otherwise every output channel is overlaid on its target
    in the left column and every input channel is drawn in the right column.
    The figure is saved under ./twotasks/ with `tag` appended to the file name.

    Args:
        test_input_np: inputs indexed as (trial, time, channel); array or tensor.
        net_out: network outputs, indexed like `test_output_np`.
        test_output_np: target outputs indexed as (trial, time, channel).
        test_task: optional per-trial task indices used to title each row with
            the matching entry of task_params['rules'].
        tag: suffix for the saved file name.
        batch_num: number of trials (rows) to draw.
    """
    test_input_np = helper.to_ndarray(test_input_np)
    net_out = helper.to_ndarray(net_out)
    test_output_np = helper.to_ndarray(test_output_np)

    # squeeze=False keeps the axes grid 2-D even for batch_num == 1, so the
    # [row, column] indexing below always works.
    fig_all, axs_all = plt.subplots(batch_num, 2, figsize=(4*2, batch_num*2), squeeze=False)

    if test_output_np.shape[-1] == 1:
        # Single output channel: one panel per trial in the left column.
        # (This branch used to iterate over an undefined name `axs`.)
        for batch_idx, ax in enumerate(axs_all[:, 0]):
            ax.plot(net_out[batch_idx, :, 0], color=c_vals[batch_idx])
            ax.plot(test_output_np[batch_idx, :, 0], color=c_vals_l[batch_idx])

    else:
        for batch_idx in range(batch_num):
            out_ax, in_ax = axs_all[batch_idx, 0], axs_all[batch_idx, 1]

            # Outputs: thin line for the network, wide translucent line for the target.
            for out_idx in range(test_output_np.shape[-1]):
                out_ax.plot(net_out[batch_idx, :, out_idx], color=c_vals[out_idx], label=out_idx)
                out_ax.plot(test_output_np[batch_idx, :, out_idx], color=c_vals_l[out_idx], linewidth=5, alpha=0.5)

            # Inputs, mapped back through the pseudo-inverse when they were randomized.
            input_batch = test_input_np[batch_idx, :, :]
            if task_params["randomize_inputs"]:
                input_batch = input_batch @ np.linalg.pinv(task_params["randomize_matrix"])
            for inp_idx in range(input_batch.shape[-1]):
                in_ax.plot(input_batch[:, inp_idx], color=c_vals[inp_idx], label=inp_idx)

            # Title and legend only need to be set once per panel.
            for ax in (out_ax, in_ax):
                if test_task is not None:
                    ax.set_title(f"{task_params['rules'][test_task[batch_idx]]}")
                ax.legend()

    for ax in axs_all.flatten():
        ax.set_ylim([-2, 2])
    fig_all.tight_layout()
    fig_all.savefig(f"./twotasks/lowD_{hyp_dict['ruleset']}_{hyp_dict['chosen_network']}_seed{seed}_{hyp_dict['addon_name']}_{tag}.png", dpi=300)

# Normal-delay test set, using the outputs selected in the previous cell.
plot_input_output(test_input_np, net_out, test_output_np, test_task, tag="")

In [None]:
# Long-delay test set (second entry of netout_lst) at checkpoint `ind`.
plot_input_output(test_input_attractor, netout_lst[1][ind], test_output_attractor, test_task_attractor, tag="long")

In [None]:
def dimensionality_measure(W, expected_dim=None):
    """
    Normalized participation ratio of the covariance of `W`.

    Dimensionality in recurrent spiking networks: Global trends in activity and local origins in
    connectivity (Equation 3)
    "The dimensionality is a weighted measure of the number of axes explored by that cloud"
    Recanatesi, et al., 2019

    Args:
        W: 2-D array with one row per variable and one column per observation
            (the layout expected by ``np.cov``).
        expected_dim: optional number of variables W must have. When omitted,
            the notebook-level ``n_hidden`` is used if it is defined, which
            reproduces the previous hard-coded check; otherwise no check is made.

    Returns:
        (sum(lambda) ** 2 / sum(lambda ** 2)) / W.shape[0] for the eigenvalues
        lambda of the covariance matrix; return value in range of (0, 1].

    Raises:
        AssertionError: if the number of variables differs from the expected one.
    """
    covW = np.cov(W)
    if expected_dim is None:
        expected_dim = globals().get("n_hidden")
    if expected_dim is not None:
        assert covW.shape[0] == expected_dim
    # The covariance matrix is symmetric, so use the symmetric solver: it always
    # returns real eigenvalues, whereas the general `eig` can return complex
    # values from round-off. The eigenvectors are not needed.
    eigenvalues = np.linalg.eigvalsh(covW)
    numerator = np.sum(eigenvalues) ** 2
    denominator = np.sum(eigenvalues ** 2)
    return (numerator / denominator) / W.shape[0]

In [None]:
# Labels of the long-delay test set. NOTE: this rebinds the notebook-level
# labels_resp / labels_stim to the long-delay values.
labels_resp, labels_stim = generate_response_stimulus(task_params_attractor, test_trials_attractor)
labels_attractor = labels_resp if color_by != "stim" else labels_stim

# One row per trial: [label, task index].
label_task_comb_attractor = np.array([
    [labels_attractor[trial][0], test_task_attractor[trial]]
    for trial in range(len(labels_attractor))
])

In [None]:
def sample_non_nan(arr, k):
    """
    Draw `k` entries, without replacement, from the non-NaN values of a NumPy array.

    The array is flattened by the boolean selection, so its shape does not
    matter. Sampling uses NumPy's global random state.

    Returns a plain Python list of length `k`.
    Raises ValueError when fewer than `k` non-NaN entries are available.
    """
    is_number = np.logical_not(np.isnan(arr))
    candidates = arr[is_number]
    if candidates.size < k:
        raise ValueError("k exceeds number of non-NaN entries.")
    drawn = np.random.choice(candidates, size=k, replace=False)
    return drawn.tolist()

In [None]:
def _matching_trials(label_task_comb, label, task):
    """Return the indices of the trials whose [label, task] row equals the given pair."""
    return [i for i, pair in enumerate(label_task_comb) if np.array_equal(pair, [label, task])]


def _trial_vectors(Ms_orig, hs, trial_indices, checktime, compare, win):
    """Return one flattened feature vector per trial, taken at time step `checktime`."""
    if compare == "hidden":
        return [hs[idx, checktime, :].flatten() for idx in trial_indices]
    # compare == "modulation": first column of M, optionally after projecting with `win`.
    if win is not None:
        return [(Ms_orig[idx, checktime, :, :] @ win)[:, 0].flatten() for idx in trial_indices]
    return [Ms_orig[idx, checktime, :, :][:, 0].flatten() for idx in trial_indices]


def _mean_pairwise_cosine(vectors_a, vectors_b):
    """Mean cosine similarity over every (a, b) pair of the two vector lists."""
    return np.mean([1 - cosine(a, b) for a in vectors_a for b in vectors_b])


def analyze_similarity(Ms_orig, hs, net_params, label_task_comb, checktime, compare="modulation"): 
    """Cosine similarity between trial representations at a single time step.

    Three comparisons are made over the 8 labels and the 2 tasks:
      * same label, different task (label k in task 0 vs label k in task 1);
      * same response, different task (label k in task 0 vs label (k + 4) % 8
        in task 1);
      * same task, different label: 100 resampling rounds, each drawing one
        random trial per label and averaging 8 randomly chosen label pairs.

    Args:
        Ms_orig: modulation array indexed as [trial, time, :, :].
        hs: hidden-activity array indexed as [trial, time, :].
        net_params: dict; 'input_layer_add' decides whether the modulation is
            multiplied by the notebook-level network's input weights
            (`net.W_initial_linear`) before its first column is taken.
        label_task_comb: per-trial rows of [label, task index].
        checktime: time step at which the representations are read.
        compare: "modulation" or "hidden".

    Returns:
        list: four [mean, std] pairs in the order same label / same response /
        task 0 different labels / task 1 different labels.

    Raises:
        ValueError: if `compare` is not one of the two supported modes
            (previously this surfaced as a NameError inside the loop).

    Note:
        When a label/task combination has no trials its mean is NaN and the
        resampling step fails inside random.choice, exactly as before.
    """
    if compare not in ("modulation", "hidden"):
        raise ValueError(f"compare must be 'modulation' or 'hidden', got {compare!r}")

    n_labels = 8          # number of distinct labels per task
    n_repeat = 100        # resampling rounds for the same-task comparison
    n_pairs_sampled = 8   # label pairs averaged in each round

    # Input-layer weights are loop-invariant and only needed for the modulation
    # comparison; they used to be re-read in every iteration of the first loop
    # and implicitly reused by the second one.
    win = None
    if net_params["input_layer_add"] and compare == "modulation":
        win = net.W_initial_linear.weight.data.detach().cpu().numpy()

    def paired_vectors(label_task0, label_task1):
        # Use equally many trials from both tasks (truncate to the shorter list).
        ind1 = _matching_trials(label_task_comb, label_task0, 0)
        ind2 = _matching_trials(label_task_comb, label_task1, 1)
        ll = min(len(ind1), len(ind2))
        return (_trial_vectors(Ms_orig, hs, ind1[:ll], checktime, compare, win),
                _trial_vectors(Ms_orig, hs, ind2[:ll], checktime, compare, win))

    inverse_modulation_ss_dt = []
    inverse_modulation_sr_dt = []
    modulation_save = [[], []]  # per task: one list of trial vectors per label

    # same stimulus (effectively anti-response), different task
    for k in range(n_labels):
        vectors_task0, vectors_task1 = paired_vectors(k, k)
        inverse_modulation_ss_dt.append(_mean_pairwise_cosine(vectors_task0, vectors_task1))
        modulation_save[0].append(vectors_task0)
        modulation_save[1].append(vectors_task1)

    # same response, different task 
    for k in range(n_labels):
        vectors_task0, vectors_task1 = paired_vectors(k, (k + n_labels // 2) % n_labels)
        inverse_modulation_sr_dt.append(_mean_pairwise_cosine(vectors_task0, vectors_task1))

    # same task, different stimulus 
    n_saved = len(modulation_save[0])
    modulation_matrices_all = []
    for _ in range(n_repeat):
        # Only the strict upper triangle is filled; everything else stays NaN
        # and is ignored by sample_non_nan.
        modulation_matrices = [
            np.full((n_saved, n_saved), np.nan),
            np.full((n_saved, n_saved), np.nan),
        ]
        for i in range(n_saved):
            for j in range(i + 1, n_saved):
                modulation_matrices[0][i, j] = 1 - cosine(random.choice(modulation_save[0][i]), random.choice(modulation_save[0][j]))
                modulation_matrices[1][i, j] = 1 - cosine(random.choice(modulation_save[1][i]), random.choice(modulation_save[1][j]))

        modulation_matrices_all.append([
            np.nanmean(sample_non_nan(modulation_matrices[0], n_pairs_sampled)),
            np.nanmean(sample_non_nan(modulation_matrices[1], n_pairs_sampled)),
        ])

    modulation_matrices_all = np.array(modulation_matrices_all)

    result = [
        [np.mean(inverse_modulation_ss_dt), np.std(inverse_modulation_ss_dt)],
        [np.mean(inverse_modulation_sr_dt), np.std(inverse_modulation_sr_dt)],
        [np.mean(modulation_matrices_all[:, 0]), np.std(modulation_matrices_all[:, 0])],
        [np.mean(modulation_matrices_all[:, 1]), np.std(modulation_matrices_all[:, 1])],
    ]

    return result

In [None]:
# here db is selected based on learning stage selection 

# Index used to build the db keys ("M{idx}", "b{idx}", "hidden{idx}") read by modulation_extraction.
layer_index = 0 # 1 layer MPN 
if net_params["input_layer_add"]:
    # Shifted by one when the extra input layer is enabled — presumably because
    # that layer takes index 0 in db; TODO confirm against the network code.
    layer_index += 1 
    
def modulation_extraction(test_input, db, layer_index, cuda=False):
    """
    Extracts modulation tensors from `db` and returns:
        Ms:      (batch, seq, features) reshaped version of M
        Ms_orig: original M (no reshape)
        hs:      (batch, seq, features) reshaped version of hidden
        bs:      bias vector/matrix as-is (or concatenated if list)

    Args:
        test_input: array/tensor whose first two dimensions give the batch size
            and the sequence length used for the reshapes.
        db: mapping holding the entries "M{layer_index}", "b{layer_index}" and
            "hidden{layer_index}". Each entry is a torch tensor, an array-like,
            or a list/tuple of those (concatenated along the last axis).
        layer_index: integer suffix of the db keys to read.
        cuda: unused; kept so existing callers keep working.

    All returned values are NumPy arrays.
    """

    def _to_numpy(x):
        # Convert torch.Tensor -> numpy, otherwise np.asarray
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def _load(key):
        # Lists/tuples are converted element by element and then concatenated
        # on the last axis. The original converted first and checked for a list
        # afterwards, so the concatenation branch could never run.
        value = db[key]
        if isinstance(value, (list, tuple)):
            return np.concatenate([_to_numpy(part) for part in value], axis=-1)
        return _to_numpy(value)

    n_batch, max_seq_len = test_input.shape[0], test_input.shape[1]

    # M
    M_raw = _load(f"M{layer_index}")
    Ms = M_raw.reshape(n_batch, max_seq_len, -1)
    Ms_orig = M_raw  # unreshaped

    # b
    bs = _load(f"b{layer_index}")

    # hidden
    H_raw = _load(f"hidden{layer_index}")
    hs = H_raw.reshape(n_batch, max_seq_len, -1)

    return Ms, Ms_orig, hs, bs

In [None]:
# across training stage
result_attractor_all = [] 
pr_all = [] 
time_stamps = {} 

# The fixation-off time depends only on the test input, not on the checkpoint,
# so it is located once here instead of once per loop iteration.
# to handle noise, find the time when fixation is off
checktime_sample = test_input_attractor[0,:,0].detach().cpu()
mask = checktime_sample < 0.5
idx = torch.nonzero(mask, as_tuple=False)
if idx.numel() == 0:
    raise ValueError("Channel 0 of the first long-delay trial never drops below 0.5.")
checktime_attractor = idx[0].item()

time_stamps["delay_end"] = checktime_attractor - 2 # a little bit before the fixation off

for db_attractor in db_lst[1]:
    _, M_attractor, h_attractor, _ = modulation_extraction(test_input_attractor, db_attractor, layer_index)

    # Participation ratio of the hidden activity, one value per trial
    # (transposed so that rows are hidden units and columns are time steps).
    prs = [dimensionality_measure(h_attractor[i,:,:].T) for i in range(h_attractor.shape[0])]
    pr_all.append([np.mean(prs), np.std(prs)])

    # Similarity of the hidden activity at the fixation-off step for this checkpoint.
    result_attractor = analyze_similarity(M_attractor, h_attractor, net_params, label_task_comb_attractor, \
                                          checktime=checktime_attractor, compare="hidden")

    result_attractor_all.append(result_attractor)

In [None]:
# Mean +/- std of the normalized participation ratio at every checkpoint.
pr_all = np.array(pr_all)
pr_mean, pr_std = pr_all[:, 0], pr_all[:, 1]

figpr, axspr = plt.subplots(1,1,figsize=(6,3))
axspr.plot(counter_lst, pr_mean, "-o", color=c_vals[0])
axspr.fill_between(counter_lst, pr_mean - pr_std, pr_mean + pr_std, color=c_vals_l[0], alpha=0.5)
axspr.set_xscale("log")
axspr.set_xlabel("# Dataset", fontsize=15)
axspr.set_ylabel("Normalized \nParticipation Ratio", fontsize=15)
axspr.tick_params(axis="both", labelsize=12)
figpr.savefig(f"./twotasks/pr_{hyp_dict['ruleset']}_seed{seed}_{hyp_dict['addon_name']}.png", dpi=300)

In [None]:
# One curve (mean with a +/- std band) per comparison type across the training checkpoints.
figattractor, axsattractor = plt.subplots(1,1,figsize=(4,4))
break_names = ["same stimulus", "same response", "task 1 different stimulus", "task 2 different stimulus"]
for i in range(len(result_attractor_all[0])):
    mean = [rs[i][0] for rs in result_attractor_all]
    std = [rs[i][1] for rs in result_attractor_all]
    band_low = [m - s for m, s in zip(mean, std)]
    band_high = [m + s for m, s in zip(mean, std)]
    axsattractor.plot(counter_lst, mean, "-o", color=c_vals[i], label=f"{break_names[i]}")
    axsattractor.fill_between(counter_lst, band_low, band_high, alpha=0.5, color=c_vals_l[i])
axsattractor.set_xscale("log")
axsattractor.legend()
figattractor.savefig(f"./twotasks/attractor_{hyp_dict['ruleset']}_seed{seed}_{hyp_dict['addon_name']}.png", dpi=300)

In [None]:
# across different timestamp 
# Locate stimulus onset/offset from the first trial whose stimulus channels
# (input channels 2..5) are ever active. The scan is bounded by the batch size:
# the previous `while` loop caught IndexError both for "no active step" and for
# running past the last trial, so it never terminated when no trial qualified.
stimulus_start, stimulus_end = None, None
for chosen_batch in range(test_input_attractor.shape[0]):
    input_part = test_input_attractor[chosen_batch,:,2:2+4].detach().cpu().numpy()
    input_part_sum = np.sum(input_part, axis=1)
    active_steps = np.where(input_part_sum > 0.5)[0]
    if active_steps.size:
        stimulus_end = active_steps[-1]
        stimulus_start = active_steps[0] - 1  # last step before the stimulus comes on
        break

if stimulus_end is None:
    raise ValueError("No trial in test_input_attractor has summed stimulus input above 0.5.")

time_stamps["stimulus_start"] = stimulus_start
time_stamps["stimulus_end"] = stimulus_end
time_stamps["delay_start"] = stimulus_end + 1
time_stamps["trial_end"] = len(input_part_sum) - 1

# Similarity of the hidden activity of the last checkpoint at each time stamp.
_, M_attractor_end, h_attractor_end, _ = modulation_extraction(test_input_attractor, db_lst[1][-1], layer_index)
result_attractor_end_all = {} 
# Evaluate the stages in chronological order. "delay_end" is inserted into
# time_stamps by an earlier cell, so plain dict order would put it first and
# scramble the x axis of the per-stage plot built from this dict.
for key in sorted(time_stamps, key=time_stamps.get):
    result_attractor = analyze_similarity(M_attractor_end, h_attractor_end, net_params, label_task_comb_attractor, \
                                      checktime=time_stamps[key], compare="hidden")
    result_attractor_end_all[key] = result_attractor

In [None]:
# _, M_attractor_end, h_attractor_end, _ = modulation_extraction(test_input_attractor, 
#                                                                db_lst[1][-1], layer_index)

# fig, axs = plt.subplots(2,2,figsize=(4*2,4*2))
# for batch_iter in range(h_attractor_end.shape[0]): 
#     h_norms, m_norms = [], []
#     h_corr, m_corr = [], [] 
#     for time_iter in range(1, h_attractor_end.shape[1]): 
#         h_norms.append(np.linalg.norm(h_attractor_end[batch_iter, time_iter, :]))
#         m_norms.append(np.linalg.norm(M_attractor_end[batch_iter, time_iter, :, :]))
#         h_corr.append(np.linalg.norm(h_attractor_end[batch_iter, time_iter-1, :] - h_attractor_end[batch_iter, time_iter, :]))
#         m_corr.append(np.linalg.norm(M_attractor_end[batch_iter, time_iter-1, :, :].flatten() - 
#                                   M_attractor_end[batch_iter, time_iter, :, :].flatten()))
        
#     axs[0,0].plot(h_norms, color=c_vals[label_task_comb_attractor[batch_iter, 1]], alpha=0.1) 
#     axs[1,0].plot(m_norms, color=c_vals[label_task_comb_attractor[batch_iter, 1]], alpha=0.1)
#     axs[0,1].plot(h_corr, color=c_vals[label_task_comb_attractor[batch_iter, 1]], alpha=0.1) 
#     axs[1,1].plot(m_corr, color=c_vals[label_task_comb_attractor[batch_iter, 1]], alpha=0.1)
    
#     for ax in axs.flatten(): 
#         ax.set_xlim([time_stamps["delay_start"], time_stamps["delay_end"]])
#         ax.set_xlabel("Time Steps since Delay Start", fontsize=15)

# fig.tight_layout()
# fig.savefig(f"./twotasks/attractor_{hyp_dict['ruleset']}_seed{seed}_{hyp_dict['addon_name']}.png", dpi=300)

In [None]:
# One curve (mean with a +/- std band) per comparison type across the trial stages.
figattractorend, axsattractorend = plt.subplots(1,1,figsize=(6,6))
stage_results = list(result_attractor_end_all.values())
stages_counter = list(range(len(result_attractor_end_all)))
for i in range(len(result_attractor_end_all["trial_end"])):
    mean = [rs[i][0] for rs in stage_results]
    std = [rs[i][1] for rs in stage_results]
    band_low = [m - s for m, s in zip(mean, std)]
    band_high = [m + s for m, s in zip(mean, std)]
    axsattractorend.plot(stages_counter, mean, "-o", color=c_vals[i], label=f"{break_names[i]}")
    axsattractorend.fill_between(stages_counter, band_low, band_high, alpha=0.5, color=c_vals_l[i])
axsattractorend.set_xticks(stages_counter)
axsattractorend.set_xticklabels(list(result_attractor_end_all.keys()), rotation=45, ha="right", fontsize=15)
axsattractorend.legend(fontsize=15, frameon=True, loc="best")
# NOTE(review): the plotted results were computed with compare="hidden", while
# this label says "Modulation" — confirm which one is intended.
axsattractorend.set_ylabel("Cosine Similarity of Modulation", fontsize=15)
figattractorend.tight_layout()
figattractorend.savefig(f"./twotasks/attractor_stage_{hyp_dict['ruleset']}_seed{seed}_{hyp_dict['addon_name']}.png", dpi=300)

In [None]:
from itertools import chain

def input_interpolation(test_input_attractor, test_output_attractor, label_task_comb_attractor,
                        expand_stimulus=True, n_stimuli=8, n_steps=10):
    """
    Build network inputs/targets that linearly interpolate between two tasks.

    For every stimulus index ``k`` in ``range(n_stimuli)`` the first trial whose
    label row equals ``[k, 0]`` (stored as "pro") and the first trial whose label
    row equals ``[k, 1]`` (stored as "anti") are selected. The two stacks are
    then blended as ``alpha * pro + (1 - alpha) * anti`` for
    ``alpha = 0, 1/n_steps, ..., 1``; the first returned level is therefore the
    pure "anti" stack and the last one the pure "pro" stack.

    Parameters
    ----------
    test_input_attractor : torch.Tensor
        Inputs, indexed as ``[trial, :, :]``.
    test_output_attractor : torch.Tensor
        Targets, indexed the same way as the inputs.
    label_task_comb_attractor : array-like
        One row per trial; rows are compared against ``[stimulus, task]``.
    expand_stimulus : bool, default True
        If True, add for every stimulus ``i`` the average of stimulus ``i`` and
        stimulus ``(i + 1) % n_stimuli`` (inputs and targets alike). The stacks
        are then ordered ``0, mid(0,1), 1, mid(1,2), ...`` so that neighbouring
        rows stay neighbouring stimuli.
    n_stimuli : int, default 8
        Number of distinct stimulus indices to collect.
    n_steps : int, default 10
        Number of interpolation intervals (``n_steps + 1`` levels are returned).

    Returns
    -------
    alpha_lst : list of float
    stacked_interpolation : list of torch.Tensor
        One blended input stack per alpha.
    stacked_interpolation_ans : list of torch.Tensor
        One blended target stack per alpha.

    Raises
    ------
    AssertionError
        If inputs and labels disagree on the number of trials.
    ValueError
        If ``n_steps`` is smaller than 1.
    IndexError
        If some ``[stimulus, task]`` combination has no trial.
    """
    assert test_input_attractor.shape[0] == label_task_comb_attractor.shape[0]
    if n_steps < 1:
        raise ValueError("n_steps must be at least 1.")

    def _first_trial(stimulus, task):
        # Same row-wise match as scanning all trials, but stops at the first hit.
        wanted = [stimulus, task]
        for trial_idx, label in enumerate(label_task_comb_attractor):
            if np.array_equal(label, wanted):
                return trial_idx
        raise IndexError(f"no trial labelled [stimulus={stimulus}, task={task}]")

    pro_task, anti_task = {}, {}
    pro_task_answer, anti_task_answer = {}, {}
    for k in range(n_stimuli):
        pro_idx = _first_trial(k, 0)
        pro_task[k] = test_input_attractor[pro_idx, :, :]
        pro_task_answer[k] = test_output_attractor[pro_idx, :, :]

        anti_idx = _first_trial(k, 1)
        anti_task[k] = test_input_attractor[anti_idx, :, :]
        anti_task_answer[k] = test_output_attractor[anti_idx, :, :]

    # Order in which the per-stimulus entries are stacked.
    key_order = list(range(n_stimuli))

    # Expand with some unseen stimuli: midpoints of neighbouring stimuli.
    if expand_stimulus:
        for i in range(n_stimuli):
            neighbour = (i + 1) % n_stimuli  # wrap-around indexing
            for table in (pro_task, anti_task, pro_task_answer, anti_task_answer):
                table[n_stimuli + i] = (table[i] + table[neighbour]) / 2

        # Interleave originals and midpoints. The stacks below follow this
        # order explicitly; sorting the keys would push all midpoints to the end.
        key_order = [k for i in range(n_stimuli) for k in (i, n_stimuli + i)]

    alpha_lst = [step / n_steps for step in range(n_steps + 1)]

    stacked_pro = torch.stack([pro_task[k] for k in key_order])
    stacked_anti = torch.stack([anti_task[k] for k in key_order])
    stacked_pro_answer = torch.stack([pro_task_answer[k] for k in key_order])
    stacked_anti_answer = torch.stack([anti_task_answer[k] for k in key_order])

    stacked_interpolation = [a * stacked_pro + (1 - a) * stacked_anti for a in alpha_lst]
    stacked_interpolation_ans = [a * stacked_pro_answer + (1 - a) * stacked_anti_answer for a in alpha_lst]

    return alpha_lst, stacked_interpolation, stacked_interpolation_ans

# Blend the two task variants on the original stimuli only (no midpoint stimuli).
(
    alpha_lst,
    stacked_interpolation,
    stacked_interpolation_answer,
) = input_interpolation(
    test_input_attractor,
    test_output_attractor,
    label_task_comb_attractor,
    expand_stimulus=False,
)

In [None]:
# 0 Red, 1 blue, 2 green, 3 purple, 4 orange, 5 teal, 6 gray, 7 pink, 8 yellow
# For each activity type ("hidden", "modulation"): fit a 3-component PCA on all
# time steps of all trials, project every trial, and draw the three PC pairs.
# Colour follows label column 0, line style follows label column 1.
# The projections are collected in `projected_data_all`.
names = ["hidden", "modulation"]
projected_data_all = [] 

for name in names:
    fighs, axshs = plt.subplots(1,3,figsize=(5*3,5*1))
    
    PCA_downsample = 3
        
    # None of the arguments depends on `name`, so this extraction is repeated
    # unchanged on every pass of the loop.
    Ms, Ms_orig, hs, bs = modulation_extraction(test_input_attractor, db_lst[1][-1], layer_index)
    batch_num = Ms_orig.shape[0]
    
    if name == "modulation": 
        data = Ms
    elif name == "hidden":
        data = hs 
        
    print(f"data.shape: {data.shape}")
    
    pca = PCA(n_components = PCA_downsample, random_state=42)
    n_activity = data.shape[-1] 
    print(f"n_activity: {n_activity}")
    # All-zero activity vector; its PCA image is stored in `zeros_pca` below.
    activity_zero = np.zeros((1, n_activity))
    
    mask_task1 = label_task_comb_attractor[:,1] == 1
    # NOTE(review): this mask tests column 0 while mask_task1 tests column 1.
    # It is not used in this cell; confirm whether column 1 was intended.
    mask_task0 = label_task_comb_attractor[:,0] == 0
        
    # The three period-specific slices below are not used again in this cell;
    # the PCA is fitted on every time step (`as_flat`).
    as_flat_task1_delay = data[mask_task1][:,time_stamps["delay_start"]:time_stamps["delay_end"],:].reshape((-1, n_activity))
    as_flat_delay = data[:,time_stamps["delay_start"]:time_stamps["delay_end"],:].reshape((-1, n_activity))
    as_flat_stimulus = data[:,time_stamps["stimulus_start"]:time_stamps["stimulus_end"],:].reshape((-1, n_activity))
    
    as_flat = data.reshape((-1, n_activity))    
    pca.fit(as_flat)
    
    # Fraction of variance captured by the retained components.
    total_ev_training = pca.explained_variance_ratio_.sum()  
    print(total_ev_training)

    # `wout_proj` is only assigned on the "hidden" pass.
    # Assumes the rows of W_output have n_activity entries - TODO confirm.
    if name == "hidden": 
        wout = net.W_output.detach().cpu().numpy() 
        wout_proj = pca.transform(wout) 
    
    as_pca = pca.transform(as_flat)
    # Back to (trial, time, component).
    projected_data = as_pca.reshape((data.shape[0], data.shape[1], -1))
    print(projected_data.shape)
    projected_data_all.append(projected_data)
    # Overwritten on every pass: after the loop this holds the value of the last name.
    zeros_pca = pca.transform(activity_zero)
    
    # Pairs of principal components shown in the three panels.
    combination = [[0,1],[0,2],[1,2]]
    
    for i in range(batch_num):
        data_batch = projected_data[i,:,:]
        if label_task_comb_attractor[i,1] in (0,1,): 
            for index, comb in enumerate(combination):
                # Whole trajectory from stimulus onset to trial end as a faint line ...
                axshs[index].plot(data_batch[time_stamps["stimulus_start"]:time_stamps["trial_end"],comb[0]], data_batch[time_stamps["stimulus_start"]:time_stamps["trial_end"],comb[1]], c=c_vals[label_task_comb_attractor[i,0]], \
                                         linestyle=linestyles[label_task_comb_attractor[i,1]], alpha=0.01)
                # ... plus one marker shape per period: stimulus_start -> delay_start,
                # delay_start -> delay_end, delay_end -> trial_end.
                axshs[index].scatter(data_batch[time_stamps["stimulus_start"]:time_stamps["delay_start"],comb[0]], data_batch[time_stamps["stimulus_start"]:time_stamps["delay_start"],comb[1]], c=c_vals[label_task_comb_attractor[i,0]], \
                                         marker=markers_vals[1], alpha=0.01)
                axshs[index].scatter(data_batch[time_stamps["delay_start"]:time_stamps["delay_end"],comb[0]], data_batch[time_stamps["delay_start"]:time_stamps["delay_end"],comb[1]], c=c_vals[label_task_comb_attractor[i,0]], \
                                         marker=markers_vals[2], alpha=0.01)
                axshs[index].scatter(data_batch[time_stamps["delay_end"]:time_stamps["trial_end"],comb[0]], data_batch[time_stamps["delay_end"]:time_stamps["trial_end"],comb[1]], c=c_vals[label_task_comb_attractor[i,0]], \
                                         marker=markers_vals[3], alpha=0.01)
                axshs[index].set_xlabel(f"PCA {comb[0]+1}")
                axshs[index].set_ylabel(f"PCA {comb[1]+1}")
        
    fighs.tight_layout()
    fighs.savefig(f"./twotasks/m_pca_{name}_seed{seed}_{hyp_dict['addon_name']}.png", dpi=300)

In [None]:
import plotly.graph_objects as go

# For every projected data set: draw each trial's 3-D PCA trajectory up to
# delay_end, mark the image of the zero activity vector, and overlay a square
# patch of the plane spanned by the first two rows of `wout_proj`.
# NOTE(review): `batch_num`, `zeros_pca` and `wout_proj` are not defined in this
# cell. Presumably they are left over from the cell that filled
# `projected_data_all`; verify that they belong to the same PCA fit as the
# `projected_data` being drawn for each `ind`.
for ind, projected_data in enumerate(projected_data_all): 
    fig = go.Figure()
    
    for i in range(batch_num):
        data_batch = projected_data[i, :time_stamps["delay_end"], :]
        fig.add_trace(
            go.Scatter3d(
                x=data_batch[:,0], y=data_batch[:,1], z=data_batch[:,2],
                mode="lines",
                line=dict(width=2, color=c_vals[label_task_comb_attractor[i,0]]),
                opacity=0.5,
                showlegend=False           
            )
        )
    
    # origin point (PCA image of the zero activity vector)
    fig.add_trace(
        go.Scatter3d(
            x=[zeros_pca[0, 0]], y=[zeros_pca[0, 1]], z=[zeros_pca[0, 2]],
            mode="markers",
            marker=dict(size=4, color="black"),
            showlegend=False
        )
    )

    zero_pt = zeros_pca[0]
    print(zero_pt.shape)
    
    # define the two spanning vectors: the first two rows of wout_proj, used as-is
    # NOTE(review): zero_pt is not subtracted here - confirm that these rows are
    # meant as directions rather than as points.
    v1 = wout_proj[0,:]
    v2 = wout_proj[1,:]
    
    # pick a side-length that matches the overall scale of the data:
    # half of the largest distance between any plotted point and zero_pt
    traj = projected_data[:, :time_stamps["delay_end"], :].reshape(-1, 3)
    plane_half = 0.5 * np.linalg.norm(traj - zero_pt, axis=1).max()           
    
    # build an orthonormal basis of the v1-v2 plane (Gram-Schmidt step):
    # u_hat is v1 normalised, v_hat is the normalised part of v2 orthogonal to u_hat
    u_hat = v1 / np.linalg.norm(v1)
    v2_proj = v2 - v2.dot(u_hat) * u_hat      
    v_hat = v2_proj / np.linalg.norm(v2_proj)
    
    # four corners of a square patch centred at PCA coordinate (0, 0, 0)
    # NOTE(review): zero_pt is not added to the corners, whereas the projection
    # further down is taken relative to zero_pt - confirm the intended centre.
    corners = np.array([
        -plane_half*u_hat - plane_half*v_hat,
         plane_half*u_hat - plane_half*v_hat,
         plane_half*u_hat + plane_half*v_hat,
        -plane_half*u_hat + plane_half*v_hat
    ])
    
    # The square is drawn as two triangles: corners (0, 1, 2) and (0, 2, 3).
    fig.add_trace(
        go.Mesh3d(
            x=corners[:, 0],
            y=corners[:, 1],
            z=corners[:, 2],
            i=[0, 0],
            j=[1, 2],
            k=[2, 3],
            opacity=0.25,
            color="lightblue",
            name="spanning plane",
            showscale=False
        )
    )
    
    fig.update_layout(
        scene=dict(
            xaxis_title="PCA 1",
            yaxis_title="PCA 2",
            zaxis_title="PCA 3"
        ),
        width=800,          
        height=800,
        margin=dict(l=0, r=0, t=40, b=0),
        showlegend=False     
    )
    
    fig.show()

    # Only for the first data set: 2-D scatter of the (u_hat, v_hat) coordinates
    # of each trial's state one step after delay_end, relative to zero_pt.
    if ind == 0: 
        endpoints = projected_data[:,time_stamps["delay_end"]+1,:]
        figproj, axproj = plt.subplots(1,1,figsize=(4,4))
        for ei in range(endpoints.shape[0]):
            endpoint = endpoints[ei,:] - zero_pt
            u_coord = endpoint.dot(u_hat)
            v_coord = endpoint.dot(v_hat) 
            # 3-D position of the in-plane projection; not used further in this cell.
            endpoint_proj  = zero_pt + u_coord*u_hat + v_coord*v_hat
            
            # Trials with label column 1 == 0 keep the colour of label column 0;
            # all other trials are shifted by 4 colours (mod 8).
            if label_task_comb_attractor[ei,1] == 0: 
                color_index = label_task_comb_attractor[ei,0] 
            else: 
                color_index = (label_task_comb_attractor[ei,0] + 4) % 8 
    
            axproj.scatter(u_coord, v_coord, c=c_vals[color_index], alpha=0.1)
            
        figproj.show()
        


In [None]:
from matplotlib.colors import SymLogNorm
from scipy.spatial import ConvexHull   # only needed for 3-D volume

def ring_length(pts: np.ndarray) -> float:
    """
    Perimeter of the closed polygon whose vertices are the rows of ``pts``.

    Rows are visited in order and the last row is joined back to the first,
    so the result is the sum of the Euclidean lengths of all edges.
    """
    # Row i of `following` is vertex i + 1; the final row wraps to vertex 0.
    following = np.roll(pts, -1, axis=0)
    edge_lengths = np.linalg.norm(following - pts, axis=1)
    return edge_lengths.sum()

def ring_volume_3d(pts: np.ndarray) -> float:
    """
    Volume of the convex hull of a 3-D point set.

    Parameters
    ----------
    pts : np.ndarray
        Array with one point per row and exactly three columns.

    Returns
    -------
    float
        The hull volume (never negative). Point sets that span no volume
        (fewer than four points, or all points on a common plane or line)
        give ``0.0``.

    Raises
    ------
    ValueError
        If ``pts`` does not have three columns.
    scipy.spatial.QhullError
        If Qhull fails on a point set that is not degenerate in the sense above.
    """
    if pts.shape[1] != 3:
        raise ValueError("ring_volume_3d expects a 3-D point set.")

    # Local import: only needed for the degenerate-input fallback below.
    from scipy.spatial import QhullError

    try:
        hull = ConvexHull(pts)                # triangulated convex surface
    except QhullError:
        # Qhull cannot build a hull from a flat point set. Such a set encloses
        # no volume, so report 0.0; anything else is re-raised unchanged.
        centered = np.asarray(pts, dtype=float) - np.mean(pts, axis=0)
        if centered.shape[0] < 4 or np.linalg.matrix_rank(centered) < 3:
            return 0.0
        raise
    return hull.volume

In [None]:
# 0 Red, 1 blue, 2 green, 3 purple, 4 orange, 5 teal, 6 gray, 7 pink, 8 yellow
# For each activity type ("hidden", "modulation"):
#   * run the network on every interpolation level in `stacked_interpolation`,
#   * take the activity at time_stamps["delay_end"] (called the fixed point here),
#   * record ring perimeter / mean norm in the original space and the convex-hull
#     volume in a 3-component PCA space fitted on the delay period of level 0,
#   * plot the projected points per level, and how each trial's point moves
#     across levels.
names = ["hidden", "modulation"]

raw_data_ring = [[], []]
raw_data_ring_magnitude = [[], []]
projected_data_ring = [[], []]


def numbered_markers(n):
    """
    Return the list ['$0$', '$1$', ..., '$(n-1)$'] of mathtext strings that
    Matplotlib accepts as marker styles.
    """
    return [f'${i}$' for i in range(n)]


for name_index, name in enumerate(names):
    fighs, axshs = plt.subplots(1, 3, figsize=(5*3, 5*1))
    fighsadd, axshsadd = plt.subplots(1, 3, figsize=(5*3, 5*1))

    fig3dfix = go.Figure()

    PCA_downsample = 3
    # Pairs of principal components shown in the three panels.
    combination = [[0, 1], [0, 2], [1, 2]]

    # Axis labels depend only on the panel, so set them once (1-based on both axes).
    for index, comb in enumerate(combination):
        for panel in (axshs[index], axshsadd[index]):
            panel.set_xlabel(f"PCA {comb[0]+1}")
            panel.set_ylabel(f"PCA {comb[1]+1}")

    # One label (and hence one colour) per trial of an interpolation level.
    interpolation_label = list(range(len(stacked_interpolation[0])))
    print(interpolation_label)

    marker_new = numbered_markers(len(stacked_interpolation))
    last_level = len(stacked_interpolation) - 1

    projected_data_fix_all = []

    for (int_index, int_input) in enumerate(stacked_interpolation):
        stack_output, _, db_intp = net.iterate_sequence_batch(int_input, run_mode='track_states')

        Ms, Ms_orig, hs, bs = modulation_extraction(int_input, db_intp, layer_index, cuda=True)
        batch_num = Ms_orig.shape[0]

        if name == "hidden":
            data = hs
        elif name == "modulation":
            data = Ms
        print(f"data.shape: {data.shape}")
        n_activity = data.shape[-1]

        # extract the delay period information
        as_flat_delay_ = data[:, time_stamps["delay_start"]:time_stamps["delay_end"], :]
        as_flat_delay = as_flat_delay_.reshape((-1, n_activity))

        # fixed point in original dimension
        as_flat_fixedpoint_raw = data[:, time_stamps["delay_end"], :]

        raw_data_ring[name_index].append(ring_length(as_flat_fixedpoint_raw))
        fixpt_norm = np.linalg.norm(as_flat_fixedpoint_raw, axis=1)
        raw_data_ring_magnitude[name_index].append(fixpt_norm.mean())

        as_flat = data.reshape((-1, n_activity))

        # The PCA axes are fitted once per activity type, on the delay period of
        # the first interpolation level, and re-used for all later levels.
        if int_index == 0:
            print("Generate New PCA axes")
            pca_delay = PCA(n_components=PCA_downsample, random_state=42)
            activity_zero = np.zeros((1, n_activity))
            pca_delay.fit(as_flat_delay)

        as_pca = pca_delay.transform(as_flat)
        projected_data = as_pca.reshape((data.shape[0], data.shape[1], -1))

        projected_data_fix = projected_data[:, time_stamps["delay_end"], :]

        projected_data_ring[name_index].append(ring_volume_3d(projected_data_fix))

        projected_data_fix_all.append(projected_data_fix)

        # First and last level are drawn with their level number as an opaque
        # marker, all intermediate levels as faint dots.
        is_end_level = int_index == 0 or int_index == last_level
        marker_value = marker_new[int_index] if is_end_level else "o"
        alpha_value = 0.1 if marker_value == "o" else 1.0

        # One scatter call per panel (instead of one per point) with per-point colours.
        point_colors = [c_vals[interpolation_label[i]] for i in range(batch_num)]
        for index, comb in enumerate(combination):
            axshs[index].scatter(projected_data_fix[:batch_num, comb[0]],
                                 projected_data_fix[:batch_num, comb[1]],
                                 c=point_colors, marker=marker_value, alpha=alpha_value)

    # (level, trial, component): how each trial's fixed point moves across levels.
    fixed_points_by_level = np.stack(projected_data_fix_all, axis=0)

    for index, comb in enumerate(combination):
        # Outline the ring of fixed points for a few selected levels.
        indices_lst = [0, 1, 2, -1]
        for it_idx, it in enumerate(indices_lst):
            xy = projected_data_fix_all[it][:, [comb[0], comb[1]]]
            num_xy = xy.shape[0]

            for xy_index in range(num_xy):
                nxt_index = (xy_index + 1) % num_xy   # close the loop
                axshsadd[index].plot([xy[xy_index, 0], xy[nxt_index, 0]],
                                     [xy[xy_index, 1], xy[nxt_index, 1]],
                                     linestyle="--", linewidth=3, color=c_vals_l[it_idx])

        # Path of each trial's fixed point through all interpolation levels.
        for i in range(len(interpolation_label)):
            fixed_points = fixed_points_by_level[:, i, :]
            axshsadd[index].plot(fixed_points[:, comb[0]],
                                 fixed_points[:, comb[1]],
                                 "-o", c=c_vals[interpolation_label[i]])

            # The 3-D view (alpha, PC1, PC2) only needs to be filled once.
            if index == 0:
                fig3dfix.add_trace(
                    go.Scatter3d(
                        x=np.array(alpha_lst), y=fixed_points[:, 0], z=fixed_points[:, 1],
                        mode="lines",
                        line=dict(width=2, color=c_vals[interpolation_label[i]]),
                        opacity=0.5,
                        name=f"S{i}",
                        showlegend=True
                    )
                )

    # for ax in axshs: 
    #     ax.set_xscale('symlog')   # region |x| < 1 is kept linear
    #     ax.set_yscale('symlog')

    # for ax in axshsadd: 
    #     ax.set_xscale('log')
    #     ax.set_yscale('log')

    # NOTE: `int_index` in the file names is the index of the last interpolation level.
    fighs.suptitle(name)
    fighs.tight_layout()
    fighs.savefig(f"./twotasks/m_pca_attractor_{name}_seed{seed}_{hyp_dict['addon_name']}_{int_index}.png", dpi=300)

    fighsadd.suptitle(name)
    fighsadd.tight_layout()
    fighsadd.savefig(f"./twotasks/m_pca_attractor_cycle_{name}_seed{seed}_{hyp_dict['addon_name']}_{int_index}.png", dpi=300)

    fig3dfix.update_layout(
        scene=dict(
            xaxis_title="Alpha",
            yaxis_title="PCA 1",
            zaxis_title="PCA 2"
        ),
        width=700,
        height=700,
        margin=dict(l=0, r=0, t=40, b=0),
        showlegend=True
    )
    fig3dfix.show()

In [None]:
# Ring statistics versus interpolation level. In each list, entry 0 is plotted
# with the "Hidden" label and entry 1 with the "Modulation" label.
# `name` and `int_index` in the file names are not set in this cell.
fig, axs = plt.subplots(1, 2, figsize=(6*2, 3))
ring_panels = (
    (axs[0], raw_data_ring, ("Hidden Ring Perimeter", "Modulation Ring Perimeter")),
    (axs[1], projected_data_ring, ("Hidden Ring Volume", "Modulation Ring Volume")),
)
for ax, ring_values, curve_labels in ring_panels:
    for which, curve_label in enumerate(curve_labels):
        ax.plot(alpha_lst, ring_values[which], "-o", color=c_vals[which], label=curve_label)
    ax.set_yscale("log")
    ax.legend(fontsize=15, frameon=True, loc="best")
    ax.set_xlabel("Interpolation Level", fontsize=15)
    ax.tick_params(axis="both", labelsize=12)
fig.tight_layout()
fig.savefig(f"./twotasks/m_pca_ring_{name}_seed{seed}_{hyp_dict['addon_name']}_{int_index}.png", dpi=300)

# Mean norm of the fixed points versus interpolation level.
fig, ax = plt.subplots(1, 1, figsize=(6, 3))
magnitude_labels = ("Hidden Average Magnitude", "Modulation Average Magnitude")
for which, curve_label in enumerate(magnitude_labels):
    ax.plot(alpha_lst, raw_data_ring_magnitude[which], "-o", color=c_vals[which], label=curve_label)
ax.legend(fontsize=15, frameon=True, loc="best")
ax.set_xlabel("Interpolation Level")
fig.savefig(f"./twotasks/m_pca_ring_magnitude_{name}_seed{seed}_{hyp_dict['addon_name']}_{int_index}.png", dpi=300)

In [None]:
# 0 Red, 1 blue, 2 green, 3 purple, 4 orange, 5 teal, 6 gray, 7 pink, 8 yellow
# Compare the two ends of the interpolation: level 0 ("anti") and the last
# level ("go"). Both are projected onto a PCA fitted on the stimulus period of
# the "anti" run. Each trial is drawn in 3-D as (PC 1, PC 2, read-out channel),
# one figure for read-out channel 1 ("cos") and one for channel 2 ("sin").
fig3dresponse_cos = go.Figure()
fig3dresponse_sin = go.Figure()
fig3dresponse = [fig3dresponse_cos, fig3dresponse_sin]

N = len(stacked_interpolation)
# Set here so that the cell does not depend on a value left behind by another cell.
PCA_downsample = 3

wout = net.W_output.detach().cpu().numpy()
print(wout.shape)

name = "hidden"

# anti, hybrid of anti and go with equal weight, and go
# (N // 2 is the middle level, i.e. alpha = 0.5 when the number of levels is odd)
anti_go = [stacked_interpolation[0], stacked_interpolation[N // 2], stacked_interpolation[-1]]

_, _, db_intp_anti = net.iterate_sequence_batch(anti_go[0], run_mode='track_states')
_, _, db_intp_go = net.iterate_sequence_batch(anti_go[2], run_mode='track_states')

# Each extraction receives the input that produced its recorded states.
Ms_anti, Ms_orig_anti, hs_anti, bs_anti = modulation_extraction(anti_go[0], db_intp_anti, layer_index)
Ms_go, Ms_orig_go, hs_go, bs_go = modulation_extraction(anti_go[2], db_intp_go, layer_index)

batch_num = Ms_orig_go.shape[0]

if name == "hidden":
    data_anti, data_go = hs_anti, hs_go
elif name == "modulation":
    data_anti, data_go = Ms_anti, Ms_go

data_all = np.concatenate((data_anti, data_go), axis=0)

n_activity = data_anti.shape[-1]

as_flat_stim = data_anti[:, time_stamps["stimulus_start"]:time_stamps["stimulus_end"], :].reshape((-1, n_activity))
# Only needed for the alternative fit that is commented out below.
as_flat_stim_all = data_all[:, time_stamps["stimulus_start"]:time_stamps["stimulus_end"], :].reshape((-1, n_activity))

as_flat_anti = data_anti.reshape((-1, n_activity))
as_flat_go = data_go.reshape((-1, n_activity))

pca_stim = PCA(n_components=PCA_downsample, random_state=42)
pca_stim.fit(as_flat_stim)
# pca_stim.fit(as_flat_stim_all)
# pca_stim.fit(wout)

total_ev_training = pca_stim.explained_variance_ratio_.sum()
print(total_ev_training)
print(pca_stim.components_.shape)

as_pca_anti = pca_stim.transform(as_flat_anti)
projected_data_anti = as_pca_anti.reshape((data_anti.shape[0], data_anti.shape[1], -1))
as_pca_go = pca_stim.transform(as_flat_go)
projected_data_go = as_pca_go.reshape((data_go.shape[0], data_go.shape[1], -1))

projected_data_stim_anti = projected_data_anti[:, :time_stamps["trial_end"], :]
projected_data_stim_go = projected_data_go[:, :time_stamps["trial_end"], :]

# Read-out of the hidden activity through the output weights.
response_anti = (hs_anti[:, :time_stamps["trial_end"], :]) @ wout.T
response_go = (hs_go[:, :time_stamps["trial_end"], :]) @ wout.T


def add_response_traces(fig, proj, response, color, label, dash=None, symbol="circle"):
    """
    Add one trial to a 3-D figure: its trajectory as a line and its last time
    step as a marker.

    `proj` is indexed as [time, component] (components 0 and 1 give x and y),
    `response` holds the z value per time step. `dash` is only passed to plotly
    when it is given.
    """
    line_style = dict(width=4, color=color)
    if dash is not None:
        line_style["dash"] = dash

    fig.add_trace(
        go.Scatter3d(
            x=np.array(proj[:, 0]), y=proj[:, 1], z=response,
            mode="lines",
            line=line_style,
            opacity=1.0,
            name=label,
            showlegend=True
        )
    )
    fig.add_trace(
        go.Scatter3d(
            x=[proj[-1, 0]],
            y=[proj[-1, 1]],
            z=[response[-1]],
            mode="markers",
            marker=dict(size=6, color=color, symbol=symbol),
            legendgroup=label,
            showlegend=False
        )
    )


for i in range(projected_data_stim_anti.shape[0]):
    for resp in range(2):
        # anti: solid line, circle end point
        add_response_traces(fig3dresponse[resp],
                            projected_data_stim_anti[i], response_anti[i, :, resp+1],
                            color=c_vals[i], label=f"S{i}")
        # go: dashed line, diamond end point, colour shifted by 4 (mod 8)
        add_response_traces(fig3dresponse[resp],
                            projected_data_stim_go[i], response_go[i, :, resp+1],
                            color=c_vals[(i+4) % 8], label=f"S{i}",
                            dash="dash", symbol="diamond")

zname = ["cos", "sin"]
for resp in range(2):
    fig3dresponse[resp].update_layout(
        scene=dict(
            xaxis_title="Memoryanti Stimulus PCA 1",
            yaxis_title="Memoryanti Stimulus PCA 2",
            zaxis_title=f"{zname[resp]} theta"
        ),
        width=700,
        height=700,
        margin=dict(l=0, r=0, t=40, b=0),
        showlegend=True
    )

    fig3dresponse[resp].show()

In [None]:
# Stimulus-period trajectories of three interpolation levels - first ("anti"),
# middle, and last ("go") - in a PCA space fitted on the stimulus period of the
# "anti" run. One figure per activity type in `names`.
N = len(stacked_interpolation)
# Set here so that the cell does not depend on a value left behind by another cell.
PCA_downsample = 3
combination = [[0, 1], [0, 2], [1, 2]]

# N // 2 is the middle level (alpha = 0.5 when the number of levels is odd).
anti_go = [stacked_interpolation[0], stacked_interpolation[N // 2], stacked_interpolation[-1]]

# The network passes do not depend on the activity type, so run them once.
_, _, db_intp_anti = net.iterate_sequence_batch(anti_go[0], run_mode='track_states')
_, _, db_inp_middle = net.iterate_sequence_batch(anti_go[1], run_mode='track_states')
_, _, db_intp_go = net.iterate_sequence_batch(anti_go[2], run_mode='track_states')

# Each extraction receives the input that produced its recorded states.
Ms_anti, Ms_orig_anti, hs_anti, bs_anti = modulation_extraction(anti_go[0], db_intp_anti, layer_index)
Ms_middle, Ms_orig_middle, hs_middle, bs_middle = modulation_extraction(anti_go[1], db_inp_middle, layer_index)
Ms_go, Ms_orig_go, hs_go, bs_go = modulation_extraction(anti_go[2], db_intp_go, layer_index)

batch_num = Ms_orig_go.shape[0]

for name in names:
    if name == "hidden":
        data_anti, data_middle, data_go = hs_anti, hs_middle, hs_go
    elif name == "modulation":
        data_anti, data_middle, data_go = Ms_anti, Ms_middle, Ms_go

    n_activity = data_anti.shape[-1]

    as_flat_stim = data_anti[:, time_stamps["stimulus_start"]:time_stamps["stimulus_end"], :].reshape((-1, n_activity))

    as_flat_anti = data_anti.reshape((-1, n_activity))
    as_flat_middle = data_middle.reshape((-1, n_activity))
    as_flat_go = data_go.reshape((-1, n_activity))

    pca_stim = PCA(n_components=PCA_downsample, random_state=42)
    pca_stim.fit(as_flat_stim)

    as_pca_anti = pca_stim.transform(as_flat_anti)
    projected_data_anti = as_pca_anti.reshape((data_anti.shape[0], data_anti.shape[1], -1))
    as_pca_middle = pca_stim.transform(as_flat_middle)
    projected_data_middle = as_pca_middle.reshape((data_middle.shape[0], data_middle.shape[1], -1))
    as_pca_go = pca_stim.transform(as_flat_go)
    projected_data_go = as_pca_go.reshape((data_go.shape[0], data_go.shape[1], -1))

    projected_data_stim_anti = projected_data_anti[:, time_stamps["stimulus_start"]:time_stamps["stimulus_end"], :]
    projected_data_stim_middle = projected_data_middle[:, time_stamps["stimulus_start"]:time_stamps["stimulus_end"], :]
    projected_data_stim_go = projected_data_go[:, time_stamps["stimulus_start"]:time_stamps["stimulus_end"], :]

    fig, axs = plt.subplots(1, 3, figsize=(4*3, 4))
    for comb_index, comb in enumerate(combination):
        ax = axs[comb_index]
        for i in range(projected_data_stim_anti.shape[0]):
            # Line style encodes the level: anti, middle, go use linestyles[0], [1], [2].
            ax.plot(projected_data_stim_anti[i, :, comb[0]], projected_data_stim_anti[i, :, comb[1]],
                    color=c_vals[i], linestyle=linestyles[0])
            ax.plot(projected_data_stim_middle[i, :, comb[0]], projected_data_stim_middle[i, :, comb[1]],
                    color=c_vals[i], linestyle=linestyles[1])
            ax.plot(projected_data_stim_go[i, :, comb[0]], projected_data_stim_go[i, :, comb[1]],
                    color=c_vals[i], linestyle=linestyles[2])
        ax.set_title(name)
        ax.set_xlabel(f"PCA {comb[0]+1}")
        ax.set_ylabel(f"PCA {comb[1]+1}")

    fig.tight_layout()
    # NOTE: `int_index` in the file name is not set in this cell.
    fig.savefig(f"./twotasks/m_pca_stimulus_{name}_seed{seed}_{hyp_dict['addon_name']}_{int_index}.png", dpi=300)

        