In [None]:
import numpy as np
import matplotlib.pyplot as plt
from model import SingleAreaModule, TwoAreaRNN 

In [None]:
# Setting high level simulation parameters
T = 12000          # samples per trial (length of the last axis of the input/latent arrays)
NUM_TRIALS = 1000  # number of trials (middle axis of the input/latent arrays)
LATENT_DIM = 2     # size of the first axis of the external-input arrays
# Each external-input array holds LATENT_DIM * NUM_TRIALS * T float64 values
# (~192 MB at the values above), so shrink T / NUM_TRIALS for quick iteration.

# Seed numpy's global RNG so Restart & Run All reproduces the same noise draws.
# NOTE(review): assumes model.py draws its randomness from np.random's global
# state — if it builds its own Generator, seed that instead.
SEED = 0
np.random.seed(SEED)

In [None]:
# Build the two-area network (a single ppc -> pfc connection) and run it.
# NOTE(review): the positional numeric arguments of SingleAreaModule are defined
# in model.py and their meaning is not visible here — TODO pass them by keyword.
ppc_module = SingleAreaModule('ppc', 0.35, 0.28387, 0.15, 0)
pfc_module = SingleAreaModule('pfc', 0.4182, 0.28387, 0, 0.04)
model_obj = TwoAreaRNN([(ppc_module, pfc_module)])

# One input-weight matrix per area, sized by LATENT_DIM so it stays consistent
# with the external-input arrays below (previously hard-coded to 2).
model_obj.add_external_input([
    0.00052 * np.eye(LATENT_DIM, dtype=float),
    np.zeros((LATENT_DIM, LATENT_DIM), dtype=float),
])

# Constant external drive per area, shaped (LATENT_DIM, NUM_TRIALS, T).
# np.full / np.zeros avoid the extra full-size temporary that `c * np.ones(...)`
# allocates; each array is ~192 MB at T=12000, NUM_TRIALS=1000, LATENT_DIM=2.
input_shape = (LATENT_DIM, NUM_TRIALS, T)
external_inputs = [np.full(input_shape, 25.0), np.zeros(input_shape)]
model_obj.forward(external_inputs)

In [None]:
# Inspecting the network weights.
# Select which matrix to print: 'fbk' -> J_fbk, 'rec' -> J_rec,
# 'fwd' -> J_fwd, 'inp' -> J_inp.
weight_type = 'inp'

WEIGHT_ATTRS = {'fbk': 'J_fbk', 'rec': 'J_rec', 'fwd': 'J_fwd', 'inp': 'J_inp'}
# Fail loudly on a typo instead of silently printing nothing.
if weight_type not in WEIGHT_ATTRS:
    raise ValueError(
        f"weight_type must be one of {sorted(WEIGHT_ATTRS)}, got {weight_type!r}"
    )
weight_attr = WEIGHT_ATTRS[weight_type]

for node, neighbors in model_obj.model.items():
    # Print the selected matrix for the node itself, then for each of its neighbours.
    # NOTE(review): a module that is both a key and a neighbour is printed more than once.
    for module in (node, *neighbors):
        print("weights_" + module.name + "={}".format(getattr(module, weight_attr)))

In [None]:
# List the name of every module that appears as a key of the model graph.
area_names = [area.name for area in model_obj.model.keys()]
for area_name in area_names:
    print(area_name)

In [None]:
# Plot every trial of every latent dimension: one figure per key of the model
# graph, one panel per dimension, each trial drawn as a black trace.
time_axis = np.arange(T)
for node in model_obj.model.keys():
    latents = np.asarray(node.latents)  # indexed below as [dim, trial, time]
    # squeeze=False keeps `axes` 2-D even if there is only one dimension.
    fig, axes = plt.subplots(latents.shape[0], 1, sharex=True, squeeze=False)
    for dim, ax in enumerate(axes[:, 0]):
        # Transposing to (time, trials) lets a single plot() call draw every trial,
        # instead of NUM_TRIALS separate calls per panel.
        ax.plot(time_axis, latents[dim, :NUM_TRIALS, :].T, 'k')
        ax.set_ylabel(f"latent {dim}")
    axes[0, 0].set_title(f"{node.name}: latents ({NUM_TRIALS} trials)")
    axes[-1, 0].set_xlabel("time step")
    plt.show()
    plt.close(fig)  # free the figure; it would otherwise persist in the kernel

In [None]:
# Overlay the i_noise traces of the first len(cols) trials: one figure per key of
# the model graph, one panel per dimension, one colour per trial.
cols = ['k', 'r']  # one colour per plotted trial; extend the list to show more trials
time_axis = np.arange(T)
for node in model_obj.model.keys():
    i_noise = np.asarray(node.i_noise)  # indexed below as [dim, trial, time]
    # squeeze=False keeps `axes` 2-D even if there is only one dimension.
    fig, axes = plt.subplots(i_noise.shape[0], 1, sharex=True, squeeze=False)
    for dim, ax in enumerate(axes[:, 0]):
        for trial, color in enumerate(cols):
            ax.plot(time_axis, i_noise[dim, trial, :], color, label=f"trial {trial}")
        ax.set_ylabel(f"i_noise dim {dim}")
    axes[0, 0].set_title(f"{node.name}: i_noise")
    axes[0, 0].legend(loc="upper right")
    axes[-1, 0].set_xlabel("time step")
    plt.show()
    plt.close(fig)  # free the figure; it would otherwise persist in the kernel