# rl5E_lite_par.ipynb

Used to run parallel simulations of the model in `rl5E_lite`, returning the mean reward and the weight matrices for each.

In [1]:
%cd ../..
from draculab import *
import numpy as np
import matplotlib.pyplot as plt
import time
from tools.visualization import plotter
from multiprocessing import Pool

/home/z/projects/draculab/notebook


In [2]:
%cd spinal/rl
from rl5E_lite_from_cfg import rl5E_net

/home/z/projects/draculab/notebook/spinal/rl


In [3]:
# The parameters and their ranges
ranges = {"C_sigma": {"low": 0.01, "high": 1., "default": 0.4 }, # yes
          "C_slope": {"low": 1., "high": 3., "default": 2. },
          "C_thresh": {"low": 0.0, "high": .2, "default":0.2},
          "C_integ_amp": {"low": 0.0, "high": 1.5, "default": 0.0},
          "C_custom_inp_del": {"low": 10, "high": 300, "default": 150}, # seemingly replaced in the code
          "M_des_out_w_abs_sum": {"low": 0.5, "high": 3., "default":1.6},
          "P_mu": {"low": .1, "high": 2., "default":1.},
          "P_inp_gain": {"low": 0.5, "high": 4., "default":2.},
          "V_slope": {"low": 0.3, "high": 2.5, "default":1.5}, #yes
          "V_delta": {"low": .2, "high": 4., "default":1.}, #yes
          "V_thresh": {"low": -0.1, "high": 2., "default":0.}, # no
          "V_td_lrate": {"low": 0.2, "high": 7., "default": 1.5}, # yes
          "V_td_gamma": {"low": 0.1, "high": 9., "default": .6}, # yes
          "V_w_sum": {"low": 10., "high": 100., "default": 60.}, # yes
          "X_slope": {"low": 2., "high": 10., "default":5.}, #yes
          "X_thresh": {"low": -0.2, "high": 1., "default":0.}, # no
          "X_del": {"low": 0., "high": 1.5, "default":0.3},
          "X_l_rate": {"low": 50., "high": 400., "default":200.},
          "X_w_sum": {"low": 10., "high": 80., "default": 30.}, # yes
          "A__M_lrate": {"low": 0.05, "high": 10., "default":5.},
          "A__M_w_sum": {"low": 0.3, "high": 2., "default":.4},
          "A__M_w_max": {"low": 0.2, "high": .4, "default":.3},
          "M__C_lrate": {"low": 1., "high": 400., "default":100.},
         }

# Create a default configuration
cfg = {}
for name in ranges:
    cfg[name] = ranges[name]['default']
    
cfg['sim_time1'] = 20 #4000.  # simulation time with switching X
cfg['sim_time2'] = 20 #500.  # simulation time with free X
cfg['pres_interv'] = 4.  # time per target presentation 
    
#cfg

In [4]:
def eval_config(cfg):
    """ Returns mean reward and weights for a given configuration. 
    
        Two simulations are run, one with switching X for cfg['sim_time1'] seconds,
        and one without switching X for cfg['sim_time2'] seconds.
        
        Args:
            cfg : parameter dictionary to initialze rl5E_net. 
            
        Returns:
            A dictionary with 4 entries:
                'mean_R1' : mean reward in the first (X switch) simulation.
                'mean_R2' : mean reward in the second (X free) simulation.
                'Vw' : L__V weight vector.
                'Xw' : L__X weight vector.
    """
    net, pops_dict = rl5E_net(cfg, 
                          pres_interv=cfg['pres_interv'],
                          rand_w=False,
                          par_heter=0.1,
                          x_switch=True,
                          V_normalize=True,
                          X_normalize=True)
    
    R = pops_dict['R']
    V = pops_dict['V']
    X = pops_dict['X']
    
    times1, data1, plant_data1 = net.run(cfg['sim_time1'])
    R_data1 = np.array(data1[R[0],:])
    mean_R1 = np.mean(R_data1)
    
    net.units[pops_dict['X'][0]].switch = False # stop swtiching
    net.units[V[0]].alpha = 0.1 * net.min_delay # decrease L__V learning rate
    net.units[X[0]].alpha = 0.05                # decrease L__X learning rate
    
    times2, data2, plant_data2 = net.run(cfg['sim_time2'])
    R_data2 = np.array(data2[R[0],:])
    mean_R2 = np.mean(R_data2)
    
    Vw = net.units[V[0]].buffer[1:,-1]
    Xw = net.units[X[0]].buffer[1:,-1]
    
    results = {'mean_R1' : mean_R1,
               'mean_R2' : mean_R2,
               'Vw' : Vw,
               'Xw' : Xw }
    
    return mean_R1


In [5]:
# Create the list of configurations
configs = [cfg, cfg.copy()]
configs[1]['pres_interv'] = 5.

In [8]:
# Parallel runs of 'eval_config'
n_procs = 2
print('Starting %d processes' % (n_procs))
start_time = time.time()
all_results = list(map(eval_config, configs)) 
# with Pool(n_procs) as p:
#     #all_results = list(p.map(eval_config, configs))
#     #all_results = p.map(eval_config, configs)
#     p.close()
#     p.join()
print('****** Processing finished after %s seconds ******' % (time.time() - start_time)) 

Starting 2 processes
t=2.03, t=6.02, t=10.02, t=14.02, t=18.02, 

TypeError: list indices must be integers or slices, not tuple

In [7]:
all_results

NameError: name 'all_results' is not defined

In [None]:
# Initialize L__V connections
net.units[V[0]].buffer[1:,-1] = np.array([2.*np.pi-abs(c[0]-c[1]) for c in
                                          net.units[V[0]].centers])

In [None]:
net.plants[P].mu = .5
for i in [0,1]:
    for syn in net.syns[C[i]]:
        if syn.type == 'rga_21':
            syn.lrate = 50.
            syn.alpha = syn.lrate * net.min_delay

In [None]:
# Setting a static value for X
net.units[X[0]].switch = False # stop switching
net.units[X[0]].thresh = 0.

In [None]:
# Setting, fixing controller weights
# standard values:
# A__M_mat = np.rray([[0.27141863, 0.29511766, 0.23279135, 0.20067437],
#                     [0.23307357, 0.1957229 , 0.28739239, 0.28380952]])

# M__C_mat = np.array([[0.0295056,  2.05108517],
#                      [2.05111951, 0.02950648]])
# limit values:
A__M_mat = np.array([[0.4, .5, 0.08, 0.018],
                    [0.05, 0. , 0.4, 0.5]])

M__C_mat = np.array([[0.,  2.15],
                     [2.15, 0.]])


for m_idx, m_id in enumerate(M):
    for c_idx, c_id in enumerate(C):
        syn_list = net.syns[c_id]
        for syn in syn_list:
            if syn.preID == m_id:
                syn.w = M__C_mat[c_idx, m_idx]
                syn.alpha = 1e-4 # slowing down learning
                break

for a_idx, a_id in enumerate(A):
    for m_idx, m_id in enumerate(M):
        syn_list = net.syns[m_id]
        for syn in syn_list:
            if syn.preID == a_id:
                syn.w = A__M_mat[m_idx, a_idx]
                syn.alpha = 1e-4 # slowing down learning
                break
                

In [None]:
sim_time = 4.
#ratio = 15.1 # cns-amd
ratio = 1.4 # breaker
#ratio = 1.6 # breaker, no L,V
secs2finish = sim_time * ratio
lt1 = time.localtime()
hrs, hrs_rem = divmod(secs2finish, 3600)
mins, mns_rem = divmod(hrs_rem, 60)
xtra_hrs, new_mins = divmod(lt1.tm_min+mins, 60)
print("Expecting to finish at: %d:%d (%d seconds)" % 
      (lt1.tm_hour+hrs+xtra_hrs, new_mins, secs2finish))
start_time = time.time()

times, data, plant_data  = net.run(sim_time)

print('Execution time is %s seconds' % (time.time() - start_time))
lt2 = time.localtime
print("Finished at " + time.strftime('%H:%M'))
data = np.array(data)

# import cProfile
# import pstats
# cProfile.run('times, data, plant_data = net.run(2.)', 'restats')
# prof = pstats.Stats('restats')
# prof.sort_stats('cumulative').print_stats(30)
# data = np.array(data

In [None]:
# running in two stages. Initial high learning rate and viscosity.
sim_time = 150.
#ratio = 15.1 # cns-amd
#ratio = 20. # breaker
ratio = 1.6 # breaker, no L,V
secs2finish = sim_time * ratio
lt1 = time.localtime()
hrs, hrs_rem = divmod(secs2finish, 3600)
mins, mns_rem = divmod(hrs_rem, 60)
xtra_hrs, new_mins = divmod(lt1.tm_min+mins, 60)
print("Expecting to finish firstst simulation at %d:%d (%d seconds)" % 
      (lt1.tm_hour+hrs+xtra_hrs, new_mins, secs2finish))
start_time = time.time()

times, data, plant_data  = net.run(sim_time)

print('Initial execution time is %s seconds' % (time.time() - start_time))
lt2 = time.localtime
print("Finished at " + time.strftime('%H:%M'))
data = np.array(data)

sim_time = 100.
net.plants[0].mu = 0.5
for i in [0,1]:
    for syn in net.syns[C[i]]:
        if syn.type == 'rga_21':
            syn.lrate = 100.
            syn.alpha = syn.lrate * net.min_delay

secs2finish = sim_time * ratio
lt1 = time.localtime()
hrs, hrs_rem = divmod(secs2finish, 3600)
mins, mns_rem = divmod(hrs_rem, 60)
xtra_hrs, new_mins = divmod(lt1.tm_min+mins, 60)
print("Expecting to finish at: %d:%d (%d seconds)" % 
      (lt1.tm_hour+hrs+xtra_hrs, new_mins, secs2finish))
start_time = time.time()

times, data, plant_data  = net.run(sim_time)

print('Second execution time is %s seconds' % (time.time() - start_time))
lt2 = time.localtime
print("Finished at " + time.strftime('%H:%M'))
data = np.array(data)

In [None]:
net.units[SP[0]].set_function(lambda t: des_sf[int(round(t/15.))])

In [None]:
# reducing the scope of the plots
data_back = data
times_back = times
plant_data_back = [np.array([])]
plant_data_back[0] = plant_data[0]

first_idx=0*200
second_idx=1*200
times = times[first_idx:second_idx]
data = data[:, first_idx:second_idx]
plant_data[0] = plant_data[0][first_idx:second_idx,:]

In [None]:
# recover the data
data = data_back
plant_data[0] = plant_data_back[0]
times = times_back

In [None]:
fs = (20,6)

# M
M_fig = plt.figure(figsize=fs)
M_data = np.array(data[M])
plt.plot(times, M_data.transpose())
plt.legend(['M0', 'M1'])
plt.title('M0, M1')
#print(M_data[:,-1])

# MPLEX
MPLEX_fig = plt.figure(figsize=fs) #(30,12))
MPLEX_data = np.array(data[MPLEX])
plt.plot(times, MPLEX_data.transpose())
plt.legend(['MPLEX0', 'MPLEX1'])
plt.title('MPLEX0, MPLEX1')

# V, R
V_fig = plt.figure(figsize=(30,10)) #fs)
V_data = np.array(data[V])
R_data = np.array(data[R])
#R = np.exp(-net.units[V[0]].R_wid * np.abs(MPLEX_data[0,:]-MPLEX_data[1,:]))
plt.plot(times, V_data.transpose())
plt.plot(times, R_data.transpose())
plt.title('V, R')
#plt.plot(times, R, linewidth=4)
plt.legend(['V', 'R'])

# # M--C0 weights
# W_fig1 = plt.figure(figsize=fs)
# w_track_data = np.array(data[M_C0_track])
# plt.plot(times, w_track_data.transpose())
# plt.legend(['M0-C0', 'M1-C0'])
# plt.title('M--C0 weights')

# # A--M0 weights
# W_fig2 = plt.figure(figsize=fs)
# w_track_data2 = np.array(data[A_M0_track])
# plt.plot(times, w_track_data2.transpose())
# plt.legend(['A0-M0', 'A1-M0', 'A2-M0', 'A3-M0'])
# plt.title('A--M0 weights')

plt.show()

In [None]:
# X
X_fig = plt.figure(figsize=fs)
X_data = np.array(data[X])
plt.plot(times, X_data.transpose())
plt.plot(times, 0.5*np.ones(len(times)), 'k', linewidth=1)
#plt.ylim([-0.05,1.05])
#plt.legend(['X'])
plt.title('X')

# SF, SP
SF_fig = plt.figure(figsize=fs)
SF_data = np.array(data[SF])
SP_data = np.array(data[SP])
plt.plot(times, SF_data.transpose(), label='SF')
plt.plot(times, SP_data.transpose(), label='SP', linewidth=4)
plt.legend()
plt.title('SF, SP')
plt.show()
#print('SF = [%f]' % (SF_data[0,-1]))
#print('SP = [%f]' % (SP_data[0,-1]))

# SPF
fs = (20,6)
SPF_fig = plt.figure(figsize=fs)
SPF_data = np.array(data[SPF])
plt.plot(times, SPF_data.transpose())
plt.legend(['SPF0', 'SPF1'])
plt.title('SPF')

In [None]:
# P, DA
pres_interv=5.
P_fig = plt.figure(figsize=fs)
P_data = plant_data[P]
plt.plot(times, P_data[:,0], label='angle')
plt.plot(times, P_data[:,1], label='ang vel')
plt.plot(times, des_angs[(times/pres_interv).astype(int)], label='des_ang')
plt.legend()
plt.title('pendulum')
#print("angle: %f, vel: %f" % (P_data[-1,0],P_data[-1,1]))

# A
A_fig = plt.figure(figsize=fs)
A_data = np.array(data[A])
plt.plot(times, A_data.transpose())
plt.legend(['A0', 'A1', 'A2', 'A3'])
plt.title('A')
#print(A_data[:,-1])

# L
# L_fig = plt.figure(figsize=fs)
# L_data = np.array(data[L])
# plt.plot(times, L_data.transpose())
# plt.title('L')

# C0
C0_fig = plt.figure(figsize=fs)
C0_data = np.array(data[C[0]])
plt.plot(times, C0_data.transpose())
#plt.plot(times, data[dc_track[0]], linewidth=3)
plt.title('C0')
#print(C0_data[-1])

# C1
C1_fig = plt.figure(figsize=fs)
C1_data = np.array(data[C[1]])
plt.plot(times, C1_data.transpose())
#plt.plot(times, data[dc_track[1]], linewidth=3)
plt.title('C1')
#print(C1_data[-1])


In [None]:
max(net.units[V[0]].buffer[1:,-1]) - min(net.units[V[0]].buffer[1:,-1])

In [None]:
net.units[V[0]].buffer[1:,-1]

In [None]:
net.units[V[0]].buffer[1:,-1]
d_fig = plt.figure(figsize=(10,10))
d_ax = plt.subplot(1,1,1)
cs = d_ax.imshow(net.units[V[0]].buffer[1:,-1].reshape(10,10))
d_fig.colorbar(cs)
plt.show()

In [None]:
max(net.units[X[0]].buffer[1:,-1]) - min(net.units[X[0]].buffer[1:,-1])

In [None]:
net.units[V[0]].buffer[1:,-1].sum()

In [None]:
# Visualize L__V weights
n_plots = 20 # number of plots, each at a different time
n_rows = int(np.ceil(np.sqrt(n_plots)))
n_cols = int(np.ceil(n_plots/n_rows))
t_idxs = [int(i) for i in np.linspace(0, len(times)-1, n_plots)]

L__V_fig, L__V_axs = plt.subplots(n_rows, n_cols, figsize=(fs[0], n_cols*fs[1]))
for i_plot in range(n_plots):
    row, col = np.divmod(i_plot, n_cols)
    ax = L__V_axs[row][col]
    cs = ax.imshow(data[v_track][:,t_idxs[i_plot]].reshape(No2,No2))
    ax.set_title('t='+str(times[t_idxs[i_plot]]))
    
plt.show()

In [None]:
# Visualize L__X weights
n_plots = 20 # number of plots, each at a different time
n_rows = int(np.ceil(np.sqrt(n_plots)))
n_cols = int(np.ceil(n_plots/n_rows))
t_idxs = [int(i) for i in np.linspace(0, len(times)-1, n_plots)]

L__X_fig, L__X_axs = plt.subplots(n_rows, n_cols, figsize=(fs[0], n_cols*fs[1]))
for i_plot in range(n_plots):
    row, col = np.divmod(i_plot, n_cols)
    ax = L__X_axs[row][col]
    cs = ax.imshow(data[x_track][:,t_idxs[i_plot]].reshape(No2,No2))
    ax.set_title('t='+str(times[t_idxs[i_plot]]))
    
plt.show()

In [None]:
L_out_fig = plt.figure(figsize=(8,8))
axs = plt.subplot(1,1,1)
cs = axs.imshow(net.units[V[0]].L_out_copy.reshape(10,10))
L_out_fig.colorbar(cs)
plt.show()

In [None]:
net.units[V[0]].L_out_copy.sum()

In [None]:
# Storing the M__C and A__M connections
M__C_mat = np.zeros((2,2)) # rows are target (C) neurons
for m_idx, m_id in enumerate(M):
    for c_idx, c_id in enumerate(C):
        syn_list = net.syns[c_id]
        for syn in syn_list:
            if syn.preID == m_id:
                M__C_mat[c_idx, m_idx] = syn.w
                break
                
A__M_mat = np.zeros((2,4)) # rows are target (M) neurons
for a_idx, a_id in enumerate(A):
    for m_idx, m_id in enumerate(M):
        syn_list = net.syns[m_id]
        for syn in syn_list:
            if syn.preID == a_id:
                A__M_mat[m_idx, a_idx] = syn.w
                break
print(M__C_mat)
print(A__M_mat)

In [None]:
plotty = plotter(net, times, data)

In [None]:
plotty.act_anim(S1+S2+L+R+V, 0.5, interv=10, slider=False)

In [None]:
plotty.conn_anim(L, V)

In [None]:
# A figure with more formatting
SPF_fig = plt.figure(figsize=(20,10))
SF_data = np.array(data[SF])
SP_data = np.array(data[SP])
plt.plot(times, SF_data.transpose(), label='$S_P$', linewidth=2)
plt.plot(times, SP_data.transpose(), label='$S_D$', linewidth=4)
plt.xticks(fontsize=25)
plt.yticks(fontsize=25)
plt.legend(fontsize=25)
plt.xlabel('time (s)', fontsize =25)
plt.title('$S_D, S_P$', fontsize=30)
plt.show()

In [None]:
# good initial weights

# M__C
# M[0] has the F-D error, so you want C[0] (which exerts positive torque)
# to be driven by M[1] instead, and C[1] to be driven by M[0]
shift = 0
while not net.syns[C[0]][shift].type is synapse_types.rga:
    shift += 1
net.syns[C[0]][shift].w = 0.1
net.syns[C[0]][shift+1].w = 0.8 # C[0] driven by M[1]
net.syns[C[1]][shift].w = 0.8
net.syns[C[1]][shift+1].w = 0.1

# L__X, L__V
scale = 1. # maximum weight value
def dist(c1, c2):
    """ Periodic distance between 2-dim coordinates c1 and c2. 

        Assumes c1 and c2 are inside the box with corners [-0.5, -0.5], [0.5, 0.5].
    """
    x_dist = min(max(c1[0], c2[0]) - min(c1[0], c2[0]),  # "inner" distance
                 0.5-max(c1[0], c2[0]) + (min(c1[0], c2[0])+0.5) ) # "outer" distance
    y_dist = min(max(c1[1], c2[1]) - min(c1[1], c2[1]),  # "inner" distance
                 0.5-max(c1[1], c2[1]) + (min(c1[1], c2[1])+0.5) ) # "outer" distance
    return np.sqrt(x_dist*x_dist + y_dist*y_dist)

#L__V_iw = np.zeros((No2, No2))
j = 0
for i in range(len(L)):  # Setting L__V weights
    u = net.units[L[i]]
    c = u.coordinates
    d = dist(c, [c[1], c[1]])
    Vsyn = net.syns[V[0]][i]
    Xsyn = net.syns[X[0]][i+j]
    while Xsyn.type != synapse_types.diff_rm_hebbian:
        j +=1
        Xsyn = net.syns[X[0]][i+j]
    if Vsyn.preID == u.ID and Xsyn.preID == u.ID:
        a, b = divmod(i, No2)
        Vsyn.w = scale*(.5 - d)
        Xsyn.w = 0.01 if abs(c[1]-c[0]) < 0.5-max(c[1],c[0]) + min(c[1],c[0])+0.5 else .8
    else:
        print("FAILED!!!!!!!")

In [None]:
#===================================================================
#================ CREATE THE NETWORK ===============================
#===================================================================
net, pops_dict = rl5E_net(cfg, 
                          pres_interv=5.,
                          rand_w=False,
                          par_heter=0.1,
                          x_switch=True,
                          V_normalize=True,
                          X_normalize=True)

for name in pops_dict:
    exec(name + '=' + str(pops_dict[name]))
des_angs = np.array(des_angs)

In [None]:
# plot all factors in the M--C0 synaptic plasticity
fs = (20,6)
plastic_fig = plt.figure(figsize=fs)
xp_data = np.array(data[xp_track[0]])
up_data = np.array(data[up_track[0]])
sp_data = np.array(data[sp_track[0]])
spj_data = np.array(data[spj_track[0]])
plt.plot(times, xp_data)
plt.plot(times, up_data)
plt.plot(times, sp_data)
plt.plot(times, spj_data)
plt.legend(['xp', 'up', 'sp', 'spj'])

plastic_fig2 = plt.figure(figsize=fs)
f1 = up_data - xp_data
f2 = sp_data - spj_data
rule = 500. * f1 * f2
plt.plot(times, f1)
plt.plot(times, f2)
plt.plot(times, rule)
plt.plot(times, np.zeros(len(times)), 'k', linewidth=1)
plt.legend(['up - xp', 'sp - spj', 'prod'])
plt.show()

In [None]:
# checking some connections
print("Connections to M0 unit")
for idx, syn in enumerate(net.syns[M[0]]):
    pre_id = syn.preID
    pre_type = net.units[pre_id].type
    if pre_id == P and hasattr(syn, 'plant_out'):
        pre_pop = 'P'
    elif pre_id in A:
        pre_pop = 'A'
    #elif pre_id in L:
    #    pre_pop = 'L'
    elif pre_id in M:
        pre_pop = 'M'
    #elif pre_id in V:
    #    pre_pop = 'V'
    elif pre_id in SPF:
        pre_pop = 'SPF'
    else:
        pre_pop = 'erroneous'
    if pre_pop == 'P':
        plant_out = str(syn.plant_out)
    else:
        plant_out = 'None'
    print('%d) %s (%d) --> M0, w=%f, port=%d, plant_out=%s'%(idx, pre_pop, pre_id, syn.w, syn.port, plant_out))

print("Connections to M1 unit")
for idx, syn in enumerate(net.syns[M[1]]):
    pre_id = syn.preID
    pre_type = net.units[pre_id].type
    if pre_id == P and hasattr(syn, 'plant_out'):
        pre_pop = 'P'
    elif pre_id in A:
        pre_pop = 'A'
    #elif pre_id in L:
    #    pre_pop = 'L'
    elif pre_id in M:
        pre_pop = 'M'
    #elif pre_id in V:
    #    pre_pop = 'V'
    elif pre_id in SPF:
        pre_pop = 'SPF'
    else:
        pre_pop = 'erroneous'
    if pre_pop == 'P':
        plant_out = str(syn.plant_out)
    else:
        plant_out = 'None'
    print('%d) %s (%d) --> M1, w=%f, port=%d, plant_out=%s'%(idx, pre_pop, pre_id, syn.w, syn.port, plant_out))
    
print("Connections to C0 unit")
for idx, syn in enumerate(net.syns[C[0]]):
    pre_id = syn.preID
    pre_type = net.units[pre_id].type
    if pre_id == P and hasattr(syn, 'plant_out'):
        pre_pop = 'P'
    elif pre_id in A:
        pre_pop = 'A'
    elif pre_id in M:
        pre_pop = 'M'
    elif pre_id in C:
        pre_pop = 'C'
    else:
        pre_pop = 'erroneous'
    if pre_pop == 'P':
        plant_out = str(syn.plant_out)
    else:
        plant_out = 'None'
    print('%d) %s (%d) --> C0, w=%f, port=%d, plant_out=%s'%(idx, pre_pop, pre_id, syn.w, syn.port, plant_out))  

print("Connections to C1 unit")
for idx, syn in enumerate(net.syns[C[1]]):
    pre_id = syn.preID
    pre_type = net.units[pre_id].type
    if pre_id == P and hasattr(syn, 'plant_out'):
        pre_pop = 'P'
    elif pre_id in A:
        pre_pop = 'A'
    elif pre_id in M:
        pre_pop = 'M'
    elif pre_id in C:
        pre_pop = 'C'
    else:
        pre_pop = 'erroneous'
    if pre_pop == 'P':
        plant_out = str(syn.plant_out)
    else:
        plant_out = 'None'
    print('%d) %s (%d) --> C1, w=%f, port=%d, plant_out=%s'%(idx, pre_pop, pre_id, syn.w, syn.port, plant_out))  

    
print("Connections to afferent units")
for idx, syn in enumerate(net.syns[A[2]]):
    pre_id = syn.preID
    pre_type = net.units[pre_id].type
    if pre_id == P and hasattr(syn, 'plant_out'):
        pre_pop = 'P'
    elif pre_id in A:
        pre_pop = 'A'
    else:
        pre_pop = 'erroneous'
    if pre_pop == 'P':
        plant_out = str(syn.plant_out)
    else:
        plant_out = 'None'
    print('%d) %s (%d) --> A, w=%f, port=%d, plant_out=%s'%(idx, pre_pop, pre_id, syn.w, syn.port, plant_out))

print("Connections to plant")
for idx, syn in enumerate(net.plants[P].inp_syns[0]):
    pre_id = syn.preID
    pre_type = net.units[pre_id].type
    if pre_id in C:
        pre_pop = 'C'
    else:
        pre_pop = 'erroneous'
    print('%d) %s (%d, %s) --> P, w=%f'%(idx, pre_pop, pre_id, pre_type, syn.w))
    
print("Connections to MPLEX units")
for idx, syn in enumerate(net.syns[MPLEX[0]]):
    pre_id = syn.preID
    pre_type = net.units[pre_id].type
    if pre_id in X:
        pre_pop = 'X'
    elif pre_id in SF:
        pre_pop = 'SF'
    elif pre_id in SP:
        pre_pop = 'SP'
    else:
        pre_pop = 'erroneous'
    print('%d) %s (%d) --> MPLEX, w=%f, port=%d'%(idx, pre_pop, pre_id, syn.w, syn.port))  

print("Connections to SF units")
for idx, syn in enumerate(net.syns[SF[1]]):
    pre_id = syn.preID
    pre_type = net.units[pre_id].type
    if pre_id == P and hasattr(syn, 'plant_out'):
        pre_pop = 'P'
    elif pre_id in T:
        pre_pop = 'T'
    else:
        pre_pop = 'erroneous'
    print('%d) %s (%d) --> SF, w=%f, port=%d, plant_out=%s'%
          (idx, pre_pop, pre_id, syn.w, syn.port, plant_out))

print("Connections to X unit")
for idx, syn in enumerate(net.syns[X[0]]):
    pre_id = syn.preID
    pre_type = net.units[pre_id].type
    if pre_id in V:
        pre_pop = 'V'
    elif pre_id in MPLEX:
        pre_pop = 'MPLEX'
    else:
        pre_pop = 'erroneous'
    print('%d) %s (%d) --> X, w=%f, port=%d'%(idx, pre_pop, pre_id, syn.w, syn.port))  
    
print("Connections to V unit")
for idx, syn in enumerate(net.syns[V[0]]):
    pre_id = syn.preID
    pre_type = net.units[pre_id].type
    if pre_id in MPLEX:
        pre_pop = 'MPLEX'
    elif pre_id in R:
        pre_pop = 'R'
    else:
        pre_pop = 'erroneous'
    print('%d) %s (%d) --> V, w=%f, port=%d'%(idx, pre_pop, pre_id, syn.w, syn.port))  


---