## In this notebook we generate the data for the model update task.
We partitioned the notebook into three parts:
1. Generation of the stimulus signal, both the bottom-up and the top-down stimulus.


2. Generation of the network; in this step we can let the network train if desired, but for now we don't.
3. Generation of the snapshots for final analysis.

In [1]:
""" Imports """ 

import numpy as np

from src.signals import signal_context_switching_simple

import pickle

# Set random seed
# Seeds NumPy's legacy global RNG so that anything drawing from np.random
# afterwards is repeatable on a fresh "Restart & Run All".
# NOTE(review): presumably the src.* helpers called in later cells use the
# global np.random state rather than their own Generator — confirm.
np.random.seed(42)


# 1. Generation of the stimulus signal

In [2]:
""" 
In this cell we generate the stimuli presented to the network.
We generate three different stimuli:
    - The first is used to train the network 
        (signal_model_update.pickle)
    - The second and third are used to test the network 
        (signal_model_update_snapshot.pickle and signal_model_update_snapshot_k.pickle)

We plot the stimuli to give an indication of the task the network is trained on.
"""

dt = 0.001 #second
x_duration = 1.5 #second

prediction_offset = 0.1 #second
T=12000 #time steps

tau_stimulus = 0.2 #decay time of the stimulus

# Both snapshot signals span 8 stimulus durations, expressed in time steps
snapshot_steps = int(8*x_duration/dt)


def _generate_and_save_signal(filename, n_steps, k):
    """Generate one stimulus with signal_context_switching_simple and pickle it.

    Parameters
    ----------
    filename : str
        Path of the pickle file to write (opened in binary write mode, so an
        existing file is overwritten).
    n_steps : int
        Number of time steps, passed as the first argument of the generator.
    k : float
        Forwarded as the generator's ``k`` keyword. In this notebook k=0 is
        used for the complete mismatch and k=0.3 for the partial mismatch
        (30% change of the stimulus).

    Returns
    -------
    tuple
        The five values (t, bu, td, t1, t2) returned by the generator, which
        is also exactly what is written to ``filename``.
    """
    # NOTE(review): the meaning of the 4th positional argument (1) is not
    # visible from this notebook — see src.signals.
    t,bu,td,t1,t2 = signal_context_switching_simple(n_steps, dt, x_duration, 1, prediction_offset,tau=tau_stimulus,k=k)
    with open(filename, "wb") as file:
        pickle.dump((t,bu,td,t1,t2),file)
    return t,bu,td,t1,t2


# (file name, number of time steps, k) for every stimulus we need.
# The order is deliberately fixed: if the generator draws from the seeded
# global RNG (presumably — verify), reordering would change the signals.
signal_specs = [
    # Signal for training
    ("signal_model_update.pickle", T, 0),
    # Signal for snapshots
    # First one is a complete mismatch
    ("signal_model_update_snapshot.pickle", snapshot_steps, 0),
    # Second one is a partial mismatch (30% change of the stimulus)
    ("signal_model_update_snapshot_k.pickle", snapshot_steps, 0.3),
]

# After the loop t, bu, td, t1, t2 hold the last generated signal (the
# partial mismatch), which is the state the notebook namespace had before.
for signal_file, signal_steps, signal_k in signal_specs:
    t,bu,td,t1,t2 = _generate_and_save_signal(signal_file, signal_steps, signal_k)


# 2. Generate the network

In [3]:
""" This section generates the trained network and saves it """
from src.run_network import init_network
from src.network import (
    Random_Network,
    Psi_base_mu,
    mu_inference_step_func,
    mu_inference_spiking_condition,
    mu_inference_Psi,
)

# --- Dimensions -------------------------------------------------------------

d_stim = 3  # bottom-up stimulus dimension
d_cau = d_stim  # causal state dimension, tied to the stimulus dimension
d_td = 2  # top-down stimulus dimension

repetitions = 1  # how many times the stimulus is repeated

n = 2*d_stim*repetitions  # neuron count

# --- Rates, time constants and thresholds -----------------------------------

log_nu = -2.61
# NOTE(review): with log_nu = -2.61 this evaluates to exp(2.61) (about 13.6),
# so log_nu effectively stores -log(nu). Confirm the sign is intended.
nu = np.exp(-log_nu)  # In Hz
tau = 0.2  # decay time of the neural representation, in seconds

context_switching_threshold = 0.20023594  # threshold for context switching

# --- Base network and mappings ----------------------------------------------

# Base network; reuses dt defined in the stimulus-generation cell
Psi = Psi_base_mu(d_stim, d_td, d_cau, n, dt=dt)

# D_c maps neural space to causal state space: [I | -I] scaled by 0.15
D_c = np.concatenate((np.eye(d_stim), -np.eye(d_stim)), axis=0).T * 0.15

# D_x maps causal state space to bottom-up stimulus space (identity)
D_x = np.eye(d_stim)

# D_td maps causal state space to top-down stimulus space
D_td = np.array([[1, 0], [0, 0], [0, 1]])

beta = 300.0  # beta placeholder

# --- Assemble and initialise ------------------------------------------------

# Return value is not used; presumably Psi is configured in place — verify
# against src.network.
mu_inference_Psi(
    Psi,
    tau,
    nu=nu,
    delta_t=1,
    beta_x=beta,
    D_c=D_c,
    D_x=D_x,
    D_td=D_td,
    beta_c_low=1.0,
    beta_c_high=beta,
    context_switching_threshold=context_switching_threshold,
)
inference_network = Random_Network(
    Psi,
    mu_inference_step_func,
    mu_inference_spiking_condition,
)

init_network(inference_network)


100%|██████████| 12000/12000 [00:00<00:00, 17232.39it/s]


# 3. Run the network and generate snapshots

In [4]:
from src.run_network import generate_network_snapshots

# Number of runs requested from the snapshot generator. The recorded output
# of this cell shows one progress bar of this length for the "full mismatch"
# condition and one for the "partial mismatch" condition.
n_snapshot_runs = 50

# nu is the rate computed in the network-definition cell above.
generate_network_snapshots(nu, runs=n_snapshot_runs)


 --- full mismatch --- 


100%|██████████| 50/50 [00:40<00:00,  1.25it/s]


 --- partial mismatch --- 


100%|██████████| 50/50 [00:39<00:00,  1.26it/s]
