In [3]:
!python --version

Python 3.11.0


In [None]:
import brian2 as b2
from brian2 import *
import matplotlib.pyplot as plt



# TODO: Make weights always non-negative!
# (Excitation vs. inhibition is encoded by the sign of the synaptic current in
# the synapse models below, so weights are meant to stay >= 0.)

# Parameters for neurons (adaptive exponential integrate-and-fire, AdEx)
Cm = 281 * pF  # Membrane capacitance
gL = 30 * nS   # Leak conductance
EL = -70.6 * mV  # Leak reversal potential
Vth = -50.4 * mV  # Spike threshold
DeltaT = 2 * mV  # Slope factor
Vr = -70.6 * mV  # Reset potential
Vcut = -40 * mV  # Cutoff potential for spike generation
tau_A = 1 * ms  # Adaptation time constant
c = 4 * nS       # Coupling parameter
b = 0.0805 * nA  # Spike-triggered adaptation increment
# NOTE(review): `b` is not used by the 'v = Vr' resets in this notebook; the
# usual AdEx reset also applies 'A += b' — confirm whether that was intended.

# Parameters for synapses
tau_rise = 5 * ms  # Rise time constant for AMPA
tau_decay = 50 * ms  # Decay time constant for NMDA
w_init = 0.5 * nS   # Initial synaptic weight (conductance)

# AdEx neuron equations
eqs = '''
dv/dt = ( -gL * (v - EL) + gL * DeltaT * exp((v - Vth) / DeltaT) + I - A ) / Cm : volt
dA/dt = (c * (v - EL) - A) / tau_A : amp

I : amp  # Input current
'''

# Synapse equations (spike-trace dynamics). Y is set by each Synapses object's
# on_pre statement and feeds the trace X; the `(summed)` line sets the
# postsynaptic `I` to the sum of the per-synapse currents.
# NOTE(review): Brian2 refuses, at run time, several `(summed)` updaters that
# target the same postsynaptic variable. The error groups receive more than
# one such projection (bottom-up and top-down), so `I` may need to be split
# into one variable per input — confirm.
syn_eqs_exc = '''
dX/dt = -X / tau_decay + Y / tau_rise : volt   # Spike trace
dY/dt = -Y / tau_decay : volt  # Glutamate decay
w : siemens  # Synaptic weight (conductance)
I_post = w * X : amp (summed)
'''
# Conductance-based alternative kept for reference:
# Isyn_exc = w * X * (v_post - EL) : amp

# Inhibitory synapses share the trace dynamics but inject the *negative* of
# w * X, so a non-negative weight inhibits its target. (This string used to be
# identical to syn_eqs_exc, which made the "inhibitory" projections excitatory.)
syn_eqs_inh = '''
dX/dt = -X / tau_decay + Y / tau_rise : volt  # Spike trace
dY/dt = -Y / tau_decay : volt  # Transmitter decay
w : siemens  # Synaptic weight (conductance)
I_post = -w * X : amp (summed)
'''


# Layer sizes, bottom layer first. Full-size setting: dims = [784,400,225,64]
dims = [50,40,30,20]

gist_dim = 16  # Number of neurons in the gist population



def make_groups(dims):
  """Create one AdEx NeuronGroup per entry of `dims`.

  Every group uses the module-level model `eqs`, emits a spike when v exceeds
  Vcut, is reset to Vr, is integrated with the Euler method, and starts with
  v = EL.

  Args:
    dims: iterable of population sizes.

  Returns:
    list of NeuronGroup objects, in the same order as `dims`.
  """
  neuron_groups = [
      NeuronGroup(size, eqs, threshold='v > Vcut', reset='v = Vr', method='euler')
      for size in dims
  ]
  for neuron_group in neuron_groups:
    neuron_group.v = EL  # start every neuron at the leak reversal potential
  return neuron_groups

# Module-level copies of the three population stacks (one group per entry of
# `dims`). They are used by the one-off bottom-up wiring below;
# make_network() builds its own, independent set.
Rs, Es_0, Es_1 = (make_groups(dims) for _ in range(3))

def make_bottom_up_connections(Rs, Es_0, Es_1, syn_eqs_exc, syn_eqs_inh):
  """Wire one source group one-to-one onto a layer's two error groups.

  Args:
    Rs: source NeuronGroup (a single group, despite the plural name).
    Es_0: target group of the `syn_eqs_exc` projection.
    Es_1: target group of the `syn_eqs_inh` projection.
    syn_eqs_exc, syn_eqs_inh: Brian2 synapse model strings.

  Returns:
    (S_p, S_m): the Rs -> Es_0 and Rs -> Es_1 Synapses objects, both with
    one-to-one connectivity (i == j) and weight w_init.
  """
  # Y is declared with unit `volt` in the synapse models, so the on-spike
  # assignment must carry that unit; a bare `Y = 1` is a unit mismatch.
  S_p = Synapses(Rs, Es_0, model=syn_eqs_exc,
        on_pre='Y = 1*volt')
  S_p.connect(condition='i == j')  # One-to-one connections
  S_p.w = 'w_init'

  S_m = Synapses(Rs, Es_1, model=syn_eqs_inh,
        on_pre='Y = 1*volt')
  S_m.connect(condition='i == j')  # One-to-one connections
  S_m.w = 'w_init'

  return S_p, S_m

# One-off wiring of the lowest layer of the module-level populations.
S_p, S_m = make_bottom_up_connections(
    Rs=Rs[0],
    Es_0=Es_0[0],
    Es_1=Es_1[0],
    syn_eqs_exc=syn_eqs_exc,
    syn_eqs_inh=syn_eqs_inh,
)

def make_top_down_connections(Rs, Es_0, Es_1, syn_eqs_exc, syn_eqs_inh):
  """Wire a (higher) source group all-to-all onto a layer's two error groups.

  The pairing is crossed compared with the bottom-up wiring: the
  `syn_eqs_exc` model targets Es_1 and the `syn_eqs_inh` model targets Es_0.

  Args:
    Rs: source NeuronGroup (a single group).
    Es_0: target group of the `syn_eqs_inh` projection.
    Es_1: target group of the `syn_eqs_exc` projection.
    syn_eqs_exc, syn_eqs_inh: Brian2 synapse model strings.

  Returns:
    (S_p, S_m): the Rs -> Es_1 and Rs -> Es_0 Synapses objects, with weights
    drawn as rand() * w_init.
  """
  spike_update = 'Y = 1*volt'
  weight_init = 'rand() * w_init'  # rand() is uniform on [0, 1)

  exc_syn = Synapses(Rs, Es_1, model=syn_eqs_exc, on_pre=spike_update)
  exc_syn.connect()  # all-to-all
  exc_syn.w = weight_init

  inh_syn = Synapses(Rs, Es_0, model=syn_eqs_inh, on_pre=spike_update)
  inh_syn.connect()  # all-to-all
  inh_syn.w = weight_init

  return exc_syn, inh_syn

def make_gist_connections(Rs, Gs):
  """Connect the gist population to the representation stack.

  Rs[0] projects sparsely (p = 0.05) onto the gist group, and the gist group
  projects sparsely (p = 0.05) onto every higher group Rs[1:].

  Args:
    Rs: list of NeuronGroups, bottom layer first.
    Gs: the gist NeuronGroup (a single group).

  Returns:
    (S_gist_input, S_gist_output): the Rs[0] -> gist Synapses, and a list of
    gist -> Rs[k] Synapses for k = 1 .. len(Rs)-1 (empty if len(Rs) == 1).
  """
  # Use the parameter `Gs`; the old body referred to an undefined name `G`.
  S_gist_input = Synapses(Rs[0], Gs, model=syn_eqs_exc,
        on_pre='Y = 1*volt')
  S_gist_input.connect(p=0.05)  # Connect with 5% probability
  S_gist_input.w = 'rand() * w_init'   # This is slightly different from the paper: it should be based on a ratio

  # Keep every output projection. Reusing a single name kept only the last
  # layer's Synapses referenced and left the name unbound when Rs[1:] is empty.
  S_gist_output = []
  for R in Rs[1:]:
    S = Synapses(Gs, R, model=syn_eqs_exc,
        on_pre='Y = 1*volt')
    S.connect(p=0.05)  # Connect with 5% probability
    # Without an explicit value, w stays at Brian2's default of 0 and the
    # projection has no effect; initialise it like the input projection.
    S.w = 'rand() * w_init'
    S_gist_output.append(S)

  return S_gist_input, S_gist_output

# num_classes_per_layer and max_depth must match the parameters used when generating the dataset

def make_output_layer(num_classes_per_layer, max_depth, w_out_init = None ):
  """Build the tree of output (class) populations and its internal wiring.

  Level d, for 0 <= d < max_depth, is an AdEx NeuronGroup with
  num_classes_per_layer**d neurons, so level 0 is a single neuron. In every
  wired pair of levels, neuron j at level d receives one excitatory synapse
  from its parent j // num_classes_per_layer at level d-1.

  Args:
    num_classes_per_layer: branching factor; must match the dataset generator.
    max_depth: number of levels to create.
    w_out_init: weight of the parent->child synapses (a Brian2 quantity or an
      expression string). Defaults to w_init / 10.

  Returns:
    (Os, S_o): Os is the list of level NeuronGroups, root first; S_o is the
    list of parent->child Synapses, one per wired pair of levels (empty when
    there is nothing to wire).
  """
  if w_out_init is None:
    w_out_init = w_init / 10

  # `Os` must exist whether or not a weight was supplied (it used to be
  # created only inside the default-weight branch).
  Os = []
  for d in range(max_depth):
    # Use the neuron model `eqs`: the synapse model has no `v` (so the
    # threshold/reset were undefined) and carries a synapse-only flag.
    level = NeuronGroup(N=num_classes_per_layer**d, model=eqs,
                        threshold='v > Vcut', reset='v = Vr', method='euler')
    level.v = EL  # start at rest, as make_groups() does (0 mV is above Vcut)
    Os.append(level)

  # Keep every inter-level projection, not just the last one.
  # NOTE(review): as in the original (`if d>1`), level 0 -> level 1 is not
  # wired; confirm whether the root should also drive the first level.
  S_o = []
  for d in range(2, max_depth):
    syn = Synapses(Os[d-1], Os[d], model=syn_eqs_exc,
                   on_pre='Y = 1*volt')
    # Tree wiring: child j has exactly one parent, j // k. Brian2's `/` is
    # true division, so the former `i < j/k` linked parent i to every j > i*k.
    syn.connect(condition='i == j // num_classes_per_layer')
    syn.w = w_out_init
    S_o.append(syn)

  return Os, S_o






def make_network(dims, syn_eqs_exc, syn_eqs_inh, num_classes_per_layer, max_depth, w_out_init = None ):
  """Build every population and projection of the model.

  Args:
    dims: layer sizes, bottom layer first; one group per entry is created in
      each of the stacks Rs, Es_0 and Es_1.
    syn_eqs_exc, syn_eqs_inh: Brian2 synapse model strings used for the
      bottom-up and top-down projections.
    num_classes_per_layer, max_depth: shape of the output class tree (must
      match the dataset generator).
    w_out_init: weight of the output tree's internal synapses; None lets
      make_output_layer choose its default.

  Returns:
    (Rs, Es_0, Es_1, G, Os, connections), where `connections` maps the keys
    "output_internal", "bottom_up_{i}", "top_down_{i}", "gist_input",
    "gist_output" and "output_external" to the Synapses object(s) created.
  """
  Rs = make_groups(dims)
  Es_0 = make_groups(dims)
  Es_1 = make_groups(dims)
  G = NeuronGroup(gist_dim, eqs, threshold='v > Vcut', reset='v = Vr', method='euler')
  # Start the gist neurons at rest like make_groups() does; Brian2's default
  # v = 0 mV is above Vcut and would make them all fire on the first step.
  G.v = EL

  # Forward w_out_init; it used to be accepted here but silently dropped.
  Os, S_o_internal = make_output_layer(num_classes_per_layer, max_depth, w_out_init=w_out_init)

  connections = {}
  connections["output_internal"] = S_o_internal

  for i in range(len(Rs)):
    S_p, S_m = make_bottom_up_connections(Rs[i], Es_0[i], Es_1[i], syn_eqs_exc, syn_eqs_inh)
    connections[f"bottom_up_{i}"] = [S_p, S_m]

    # Every layer except the top also gets top-down connections from Rs[i+1].
    if i != len(Rs)-1:
      S_p, S_m = make_top_down_connections(Rs[i+1], Es_0[i], Es_1[i], syn_eqs_exc, syn_eqs_inh)
      connections[f"top_down_{i}"] = [S_p, S_m]

  S_gist_input, S_gist_output = make_gist_connections(Rs, G)
  connections["gist_input"] = S_gist_input
  connections["gist_output"] = S_gist_output

  # Synapses takes a single target group, not a list, so the all-to-all
  # projection from the top layer is built once per output level.
  # NOTE(review): if only the leaf level was intended, use Os[-1] alone.
  S_o_external = []
  for O in Os:
    S = Synapses(Rs[-1], O, model=syn_eqs_exc,
          on_pre='Y = 1*volt')
    S.connect()
    S.w = 'rand() * w_init'
    S_o_external.append(S)
  # Store the projection; it used to be created and then discarded.
  connections["output_external"] = S_o_external

  return Rs, Es_0, Es_1, G, Os, connections




# make_network returns (Rs, Es_0, Es_1, G, Os, connections). Keep the whole
# tuple as `network` and expose the connection dict under its own name
# (previously the entire tuple was bound to `connections`).
# NOTE(review): Brian2's magic run() only collects objects bound to names in
# the calling namespace, not ones nested in lists/dicts; build an explicit
# Network(...) from these objects before simulating.
network = make_network(dims, syn_eqs_exc, syn_eqs_inh, num_classes_per_layer=3, max_depth=3)
connections = network[-1]















In [None]:
def training_phase_1(network, dataset):
  """First training phase (work in progress).

  Iterates over `dataset` as (curve, label) pairs; in the code shown here the
  per-sample body is still a placeholder (`...`).

  Args:
    network: presumably the tuple returned by make_network — TODO confirm.
    dataset: iterable of (curve, label) pairs.
  """
  for curve, label in dataset:
    ...