In [1]:
import torch
import brian2 as b2
from brian2 import *
from itertools import chain
from functools import partial
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader

from dataset_util import Synthetic_Dataset_Utils


In [2]:
# TODO: Make weights always non-negative!

# Model constants and Brian2 equation strings shared by the cells below.
# The names used inside the equation strings (gL, EL, Vth, DeltaT, Cm, c, tau_A,
# tau_rise, tau_decay, ...) are looked up by Brian2 from the namespace at run
# time, so they must remain module-level names.

# Parameters for neurons
Cm = 281 * pF  # Membrane capacitance
gL = 30 * nS   # Leak conductance
EL = -70.6 * mV  # Leak reversal potential
Vth = -50.4 * mV  # Spike threshold
DeltaT = 2 * mV  # Slope factor
Vr = -70.6 * mV  # Reset potential
Vcut = -40 * mV  # Cutoff potential for spike generation (used as the threshold)
tau_A = 1 * ms  # Adaptation time constant
c = 4 * nS       # Coupling parameter (subthreshold adaptation)
b = 0.0805 * nA  # Spike-triggered adaptation increment
# NOTE(review): the groups built by make_groups / make_network reset only `v`;
# `b` is used only by the two-group trial's 'A += b' reset — confirm whether
# spike-triggered adaptation is intended for the main network.

# Parameters for synapses (both time constants are used by syn_eqs_exc / syn_eqs_inh)
tau_rise = 5 * ms  # Rise time constant for AMPA
tau_decay = 50 * ms  # Decay time constant for NMDA
w_init = 0.5 * nS   # Initial synaptic weight (conductance)

# AdEx neuron equations. `I` is a plain parameter so that synapses declaring
# `I_post = ... (summed)` can write their summed current into it.
eqs = '''
dv/dt = ( -gL * (v - EL) + gL * DeltaT * exp((v - Vth) / DeltaT) + I - A ) / Cm : volt
dA/dt = (c * (v - EL) - A) / tau_A : amp
I : amp  # Input current
'''

# Same dynamics, but the input current is read from a TimedArray named
# `input_stimuli`, which must be in scope when the network runs.
# NOTE(review): no group in the visible notebook is built from input_eqs.
input_eqs = '''
dv/dt = ( -gL * (v - EL) + gL * DeltaT * exp((v - Vth) / DeltaT) + I - A ) / Cm : volt
dA/dt = (c * (v - EL) - A) / tau_A : amp
I = input_stimuli(t)  : amp  
'''

# Output-layer equations: summed synaptic input `I` plus an external teaching
# current read from the 2-D TimedArray `output_stimuli` at column `indices`
# (`output_stimuli` must be in scope when the network runs).
out_eqs = '''
dv/dt = ( -gL * (v - EL) + gL * DeltaT * exp((v - Vth) / DeltaT) + I_tot - A ) / Cm : volt
dA/dt = (c * (v - EL) - A) / tau_A : amp
I : amp  # Synaptic input (summed)
I_tot = I + output_stimuli(t, indices) : amp
indices : integer  # dimensionless index variable
'''

# Synapse equations (spike-trace dynamics). Presynaptic spikes set Y through the
# on_pre statements; Y drives the trace X, and the current w * X is summed into
# the postsynaptic neuron's `I`. The current has no dependence on the
# postsynaptic voltage (no reversal-potential term).
syn_eqs_exc = '''
dX/dt = -X / tau_decay + Y / tau_rise : volt   # Spike trace
dY/dt = -Y / tau_decay : volt  # Glutamate decay
w : siemens  # Synaptic weight (conductance)
I_post = w * X : amp (summed)
'''
# Conductance-based alternative (unused): Isyn_exc = w * X * (v_post - EL) : amp
# NOTE(review): syn_eqs_inh is currently identical to syn_eqs_exc, so the
# "inhibitory" pathway injects positive current like the excitatory one —
# confirm whether a sign flip or an inhibitory reversal potential is intended.
syn_eqs_inh = '''
dX/dt = -X / tau_decay + Y / tau_rise : volt  # Spike trace
dY/dt = -Y / tau_decay : volt  # Glutamate decay
w : siemens  # Synaptic weight (conductance)
I_post = w * X : amp (summed)
'''




In [3]:
# Default layer sizes for the representation hierarchy (an MNIST-sized setup
# would be [784, 400, 225, 64]). `dims` is reassigned before the network is built.
dims = [50, 40, 30, 20]

# Number of neurons in the gist group G.
gist_dim = 16



def make_groups(dims, eqs = eqs):
    """Create one NeuronGroup per entry of `dims` and start every membrane at EL.

    Parameters
    ----------
    dims : iterable of int
        Number of neurons in each group, in order.
    eqs : str
        Brian2 model equations (defaults to the module-level AdEx `eqs`).

    Returns
    -------
    list of NeuronGroup
        One group per size, with threshold 'v > Vcut', reset 'v = Vr' and Euler
        integration.
    """
    neuron_groups = [
        NeuronGroup(size, eqs, threshold='v > Vcut', reset='v = Vr', method='euler')
        for size in dims
    ]

    # Initialise only after all groups exist, as before.
    for neuron_group in neuron_groups:
        neuron_group.v = EL

    return neuron_groups

In [4]:
def make_bottom_up_connections(Rs, Es_0, Es_1, syn_eqs_exc, syn_eqs_inh):
    """Connect a representation layer one-to-one onto its two error groups.

    Parameters
    ----------
    Rs : NeuronGroup
        Source (representation) layer.
    Es_0, Es_1 : NeuronGroup
        Target groups. Es_0 receives `syn_eqs_exc` synapses and Es_1 receives
        `syn_eqs_inh` synapses. Connections use the condition 'i == j', so only
        indices present in both source and target are connected.
    syn_eqs_exc, syn_eqs_inh : str
        Brian2 synapse models.

    Returns
    -------
    (S_p, S_m) : tuple of Synapses
        Rs -> Es_0 and Rs -> Es_1, both initialised with w = w_init.
    """
    # `Y` is declared in volts in both synapse models. The previous bare 'Y = 1'
    # was dimensionless, so Brian2's unit check rejected it at run time. It now
    # matches every other on_pre statement in the notebook.
    S_p = Synapses(Rs, Es_0, model=syn_eqs_exc,
        on_pre='Y = 1*volt')
    S_p.connect(condition='i == j')  # One-to-one connections
    S_p.w = 'w_init'

    S_m = Synapses(Rs, Es_1, model=syn_eqs_inh,
        on_pre='Y = 1*volt')
    S_m.connect(condition='i == j')  # One-to-one connections
    S_m.w = 'w_init'

    return S_p, S_m

In [5]:
def make_top_down_connections(Rs, Es_0, Es_1, syn_eqs_exc, syn_eqs_inh):
    """Connect a (higher) layer all-to-all onto the two error groups of the layer below.

    Note the pairing is mirrored relative to the bottom-up pathway: the
    `syn_eqs_exc` synapses target Es_1 and the `syn_eqs_inh` synapses target Es_0.
    Weights are drawn as rand() * w_init.

    Returns
    -------
    (S_p, S_m) : tuple of Synapses
        Rs -> Es_1 (syn_eqs_exc) and Rs -> Es_0 (syn_eqs_inh).
    """
    created = []
    # Creation order (excitatory first) is kept so random weights are drawn in
    # the same sequence.
    for target_group, synapse_model in ((Es_1, syn_eqs_exc), (Es_0, syn_eqs_inh)):
        synapses = Synapses(Rs, target_group, model=synapse_model,
            on_pre='Y = 1*volt')
        synapses.connect()
        synapses.w = 'rand() * w_init'
        created.append(synapses)

    return created[0], created[1]

In [6]:
def make_gist_connections(Rs, G):
    """Wire the gist group G from the first representation layer to all higher layers.

    Parameters
    ----------
    Rs : sequence of NeuronGroup
        Representation layers; Rs[0] feeds G and G feeds every Rs[k] with k >= 1.
    G : NeuronGroup
        Gist group.

    Returns
    -------
    S_gist_input : Synapses
        Rs[0] -> G, 5% connection probability, weights rand() * w_init.
    S_gist_output : list of Synapses
        One G -> Rs[k] Synapses per k >= 1 (empty when len(Rs) == 1), 5%
        connection probability, weights rand() * w_init.
    """
    S_gist_input = Synapses(Rs[0], G, model=syn_eqs_exc,
        on_pre='Y = 1*volt')
    S_gist_input.connect(p=0.05)  # Connect with 5% probability
    # This is slightly different from the paper: the weight should be based on a ratio.
    S_gist_input.w = 'rand() * w_init'

    # Keep every output projection. Previously each iteration overwrote the
    # previous one, so only G -> Rs[-1] survived. The name was also unbound
    # when len(Rs) == 1.
    S_gist_output = []
    for R in Rs[1:]:
        S = Synapses(G, R, model=syn_eqs_exc,
            on_pre='Y = 1*volt')
        S.connect(p=0.05)  # Connect with 5% probability
        # Weights were previously left at their default (0 S), making the
        # projection inert. They are now initialised like the input projection.
        S.w = 'rand() * w_init'
        S_gist_output.append(S)

    return S_gist_input, S_gist_output

In [7]:
# num_classes_per_layer and max_depth must match the parameters used when generating the dataset.

def make_output_layer(num_classes_per_layer, max_depth, w_out_init = None ):
    """Build the hierarchical output layer and the synapses between consecutive levels.

    Level d (0-based, d < max_depth) holds num_classes_per_layer**(d+1) neurons.
    All levels are concatenated in a single NeuronGroup, level 0 first. Every
    neuron of level d projects to every neuron of level d+1.

    Parameters
    ----------
    num_classes_per_layer : int
        Branching factor of the label hierarchy.
    max_depth : int
        Number of output levels.
    w_out_init : Quantity (conductance), optional
        Weight of the level-to-level synapses; defaults to w_init / 10.

    Returns
    -------
    Os : NeuronGroup
        Output neurons built from `out_eqs`. ``Os.indices[j] == j``, so neuron j
        reads column j of the 2-D ``output_stimuli`` TimedArray.
    S_o : Synapses
        Level d -> level d+1 connections (model `syn_eqs_exc`).
    """
    if w_out_init is None:
        w_out_init = w_init / 10

    level_sizes = [num_classes_per_layer ** (d + 1) for d in range(max_depth)]
    total_neurons = sum(level_sizes)

    Os = NeuronGroup(N=total_neurons, model=out_eqs, threshold='v > Vcut', reset='v = Vr', method='euler')

    S_o = Synapses(Os, Os, model=syn_eqs_exc, on_pre='Y = 1*volt')

    # Iterate over consecutive level pairs. The loop must run over depth, not
    # over the branching factor. range(num_classes_per_layer - 1) skipped the
    # deepest pairs when depth > classes, and indexed past the group when
    # classes > depth.
    start = 0
    for d in range(max_depth - 1):
        next_start = start + level_sizes[d]
        source_indices = np.arange(start, next_start)
        target_indices = np.arange(next_start, next_start + level_sizes[d + 1])

        ii, jj = np.meshgrid(source_indices, target_indices, indexing='ij')
        S_o.connect(i=ii.flatten(), j=jj.flatten())

        start = next_start

    if max_depth > 1:
        # w_out_init used to be computed but never applied, leaving every weight at 0.
        S_o.w = w_out_init

    # The output TimedArray (see get_output_current_arrays) has one column per
    # output neuron, so each neuron must read its own column. Previously
    # `indices` held the neuron's depth level, so whole levels read the same
    # wrong column.
    Os.indices = np.arange(total_neurons)

    return Os, S_o





In [8]:

def make_network(dims, syn_eqs_exc, syn_eqs_inh, num_classes_per_layer, max_depth, w_out_init = None ):
    """Assemble representation, error, gist and output populations and their synapses.

    Parameters
    ----------
    dims : list of int
        Size of each representation layer; one R, E_0 and E_1 group is built per entry.
    syn_eqs_exc, syn_eqs_inh : str
        Brian2 synapse models for the excitatory / inhibitory pathways.
    num_classes_per_layer, max_depth : int
        Shape of the label hierarchy; the output layer gets max_depth + 1 levels.
    w_out_init : Quantity (conductance), optional
        Forwarded to make_output_layer.

    Returns
    -------
    Rs, Es_0, Es_1 : list of NeuronGroup
    G : NeuronGroup
        Gist group of size `gist_dim`.
    Os : NeuronGroup
        Output layer.
    connections : dict
        Synapses keyed by pathway name ("output_internal", "bottom_up_<i>",
        "top_down_<i>", "gist_input", "gist_output", "output_external").
    """
    # NOTE(review): the output layer gets one more level than the `max_depth`
    # passed in. The original author was unsure this matches the dataset
    # tree's depth convention — confirm.
    max_depth += 1

    Rs = make_groups(dims)
    Es_0 = make_groups(dims)
    Es_1 = make_groups(dims)
    G = NeuronGroup(gist_dim, eqs, threshold='v > Vcut', reset='v = Vr', method='euler')

    # w_out_init was previously accepted but never forwarded.
    Os, S_o_internal = make_output_layer(num_classes_per_layer, max_depth, w_out_init)

    connections = {}
    connections["output_internal"] = S_o_internal

    last_layer = len(Rs) - 1
    for i in range(len(Rs)):
        S_p, S_m = make_bottom_up_connections(Rs[i], Es_0[i], Es_1[i], syn_eqs_exc, syn_eqs_inh)
        connections[f"bottom_up_{i}"] = [S_p, S_m]

        # Top-down: from layer i+1 onto the error groups of layer i. This check
        # used to sit outside the loop, where i was always the last index, so no
        # top-down synapses were ever created.
        if i != last_layer:
            S_p, S_m = make_top_down_connections(Rs[i + 1], Es_0[i], Es_1[i], syn_eqs_exc, syn_eqs_inh)
            connections[f"top_down_{i}"] = [S_p, S_m]

    S_gist_input, S_gist_output = make_gist_connections(Rs, G)
    connections["gist_input"] = S_gist_input
    connections["gist_output"] = S_gist_output

    # Rs[-1] -> every output neuron, all-to-all, with one Synapses object per
    # element yielded by iterating Os (presumably single-neuron subgroups —
    # confirm). A single Synapses(Rs[-1], Os) with connect() would create the
    # same connections.
    S_o_external = []
    for O in Os:
        S = Synapses(Rs[-1], O, model=syn_eqs_exc, on_pre='Y = 1*volt')
        S.connect()
        S.w = 'rand() * w_init'
        S_o_external.append(S)
    connections["output_external"] = S_o_external

    return Rs, Es_0, Es_1, G, Os, connections




# Helpers for preparing the network input (dataset normalisation and stimulus arrays)

In [9]:
def normalize_tensor(tensor, old_min, old_max, new_min, new_max):
    """Linearly map values from [old_min, old_max] onto [new_min, new_max].

    Works for anything supporting arithmetic (Python numbers, torch tensors).
    No guard is applied when old_max == old_min.
    """
    old_span = old_max - old_min
    new_span = new_max - new_min
    return (tensor - old_min) / old_span * new_span + new_min



In [10]:
def normalize_and_unwrap_dataset(dt, minval,maxval):
    """Flatten a labelled curve dataset into (label, curve) pairs rescaled to [minval, maxval].

    Parameters
    ----------
    dt : dict
        Must contain "categories" (whose keys are the labels to emit) and
        "curves" (mapping each of those keys to an iterable of curves). Labels
        and curves must be accepted by `torch.tensor`.
    minval, maxval : float
        Target range. The global minimum/maximum over *all* curves is mapped to
        minval/maxval, so relative amplitudes between curves are preserved.

    Returns
    -------
    list of (torch.Tensor, torch.Tensor)
        One (label, normalised curve) pair per curve, in category order; empty
        if the dataset has no curves.
    """
    all_curves = [torch.tensor(curve, dtype = torch.float64) for curve in chain.from_iterable(dt["curves"].values())]
    if not all_curves:
        return []

    # True dataset extrema. The previous version seeded both with 0, which
    # silently stretched the range to include 0 whenever all curves were
    # positive (or all negative).
    minimum = min(torch.min(curve) for curve in all_curves)
    maximum = max(torch.max(curve) for curve in all_curves)

    unwrapped = []
    for key in dt["categories"].keys():
        for curve in dt["curves"][key]:
            unwrapped.append((torch.tensor(key), normalize_tensor(torch.tensor(curve), minimum, maximum, minval, maxval)))

    return unwrapped
            

In [11]:
def get_output_stimuli_indexes_from_labels(labels, num_classes_per_layer):
    """Map hierarchical labels to positions in the concatenated output layer.

    labels[:, d] is the class chosen at level d. Within level d the neuron index
    is the mixed-radix number labels[:, 0..d] in base `num_classes_per_layer`.
    Level d then starts at offset k + k**2 + ... + k**d (k = num_classes_per_layer),
    because all levels are concatenated in one list.

    A 1-D `labels` is treated as a single row. Returns a tensor shaped like the
    (2-D) labels.
    """
    if labels.ndim == 1:
        labels = labels.unsqueeze(0)  # single sample -> shape (1, depth)

    indexes = torch.zeros_like(labels)
    position_in_level = labels.new_zeros(labels.shape[0])
    level_offset = 0

    for level in range(labels.shape[1]):
        # Horner step: extend the mixed-radix number by one more digit.
        position_in_level = position_in_level * num_classes_per_layer + labels[:, level]
        if level > 0:
            level_offset += num_classes_per_layer ** level
        indexes[:, level] = position_in_level + level_offset

    return indexes



def get_output_current_arrays(I_indexes, dim, I_value):
    """Build a (rows, dim) float32 current matrix with I_value at the given columns.

    Row r gets I_value at every column listed in I_indexes[r]; all other entries
    are 0. A 1-D I_indexes is treated as a single row.
    """
    index_matrix = I_indexes.unsqueeze(0) if I_indexes.ndim == 1 else I_indexes
    n_rows = index_matrix.shape[0]
    per_row = index_matrix.shape[1]

    currents = torch.zeros((n_rows, dim), dtype=torch.float32)

    # Row number for each flattened column index: 0,0,...,1,1,...
    row_of_each = torch.arange(n_rows).repeat_interleave(per_row)
    currents[row_of_each, index_matrix.flatten()] = I_value

    return currents


class CurveDataset(Dataset):
    """Map-style dataset yielding (label, flattened curve) pairs.

    All curves are normalised once, at construction, into [minval, maxval] by
    normalize_and_unwrap_dataset.
    """

    def __init__(self, data, minval, maxval):
        self.data = normalize_and_unwrap_dataset(data, minval, maxval)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        label, curve = self.data[idx]
        # Flatten so every sample is a 1-D vector regardless of the curve's shape.
        return label, curve.flatten()



def collate_fn(batch, msec_step, num_pause_blocks, num_classes_per_layer, out_dim):
    """Turn a batch of (labels, curve) samples into two Brian2 TimedArrays.

    Each sample becomes one stimulus row followed by `num_pause_blocks` rows of
    zeros, and every row lasts `msec_step` ms. Assumes each curve is 1-D, as
    CurveDataset yields.

    Returns
    -------
    (output_stimuli, visual_stimulus) : (TimedArray, TimedArray)
        output_stimuli has `out_dim` columns with 1 nA on the output neurons
        encoding each sample's labels; visual_stimulus holds the curves in nA.
    """

    def _with_pauses(rows, pause_row):
        # Follow every row with `num_pause_blocks` copies of `pause_row`.
        pauses = torch.tile(pause_row, (num_pause_blocks, 1))
        return torch.vstack([torch.vstack((row, pauses)) for row in rows])

    label_rows, curve_rows = zip(*batch)
    labels = torch.stack(label_rows, dim=0)
    curves = torch.stack(curve_rows, dim=0)

    visual_values = _with_pauses(curves, torch.zeros_like(curves[0]))
    visual_stimulus = TimedArray(visual_values.numpy() * nA, dt=msec_step*ms)

    target_indexes = get_output_stimuli_indexes_from_labels(labels, num_classes_per_layer)
    target_rows = get_output_current_arrays(target_indexes, out_dim, 1)
    output_values = _with_pauses(target_rows, torch.zeros(out_dim))
    output_stimuli = TimedArray(output_values.numpy() * nA, dt=msec_step*ms)

    return output_stimuli, visual_stimulus

In [12]:
# First layer has 15**2 = 225 neurons, followed by three smaller layers.
# num_classes_per_layer / max_depth must match the dataset settings below.
dims = [15 * 15, 50, 40, 30]
Rs, Es_0, Es_1, G, Os, connections = make_network(
    dims,
    syn_eqs_exc,
    syn_eqs_inh,
    num_classes_per_layer=3,
    max_depth=3,
)

INFO       Cannot use compiled code, falling back to the numpy code generation target. Note that this will likely be slower than using compiled code. Set the code generation to numpy manually to avoid this message:
prefs.codegen.target = "numpy" [brian2.devices.device.codegen_fallback]


_cython_magic_32a1cdc0f6fb3e035c7e07036edf98c2.cpp
   Creazione della libreria C:\Users\bruno\.cython\brian_extensions\Users\bruno\.cython\brian_extensions\_cython_magic_32a1cdc0f6fb3e035c7e07036edf98c2.cp311-win_amd64.lib e dell'oggetto C:\Users\bruno\.cython\brian_extensions\Users\bruno\.cython\brian_extensions\_cython_magic_32a1cdc0f6fb3e035c7e07036edf98c2.cp311-win_amd64.exp
Generazione codice in corso...
Generazione codice terminata
LINK : fatal error LNK1158: impossibile eseguire 'rc.exe'


In [13]:
# Dataset-generation settings. num_classes_per_layer and max_depth must agree
# with the values passed to make_network above.
ranges = [30, 30, 30, 30]        # range for each parameter
prior_params = [10, 10, 10, 10]  # initial parameters of the class tree
num_samples_per_class = 5
N = 15                           # presumably the curve resolution (dims[0] == 15**2) — confirm
max_depth = 3                    # depth of the class tree; adjust as needed
num_classes_per_layer = 3        # branching factor; adjust as needed
std_multiplier = 1

su = Synthetic_Dataset_Utils()
tree = su.build_tree(prior_params, 0, max_depth, num_classes_per_layer, std_multiplier, ranges)
synth_dataset = su.make_dataset(tree, num_samples_per_class=num_samples_per_class, N=N)

In [14]:
# Number of output neurons = number of columns of the output-stimulus TimedArray.
# len(Os) gives this directly; summing the lengths of per-neuron subgroups
# built the same number through N throwaway subgroup objects.
output_dim = len(Os)

batch_size = 20
num_pause_blocks = 1
msec_step = 100  # duration (ms) of each row of the stimulus TimedArrays

dataset = CurveDataset(synth_dataset, 0.6, 1.5)  # curves rescaled to 0.6-1.5 (nA once collated)
dl = DataLoader(
    dataset,
    shuffle=True,
    batch_size=batch_size,
    collate_fn=partial(
        collate_fn,
        msec_step=msec_step,
        num_pause_blocks=num_pause_blocks,
        num_classes_per_layer=num_classes_per_layer,
        out_dim=output_dim,
    ),
)

In [15]:
# Magic `run()` only collects Brian objects bound directly to a name in the
# calling namespace. Groups kept in lists (Rs, Es_0, Es_1) and synapses kept in
# the `connections` dict would be silently left out, so build the network
# explicitly. Network() adds the values of dicts and the items of lists recursively.
net = Network(Rs, Es_0, Es_1, G, Os, connections)

# Each sample occupies one stimulus row plus `num_pause_blocks` pause rows, 100 ms each.
batch_duration = batch_size * (1 + num_pause_blocks) * 100 * ms

for output_stimuli, input_stimuli in dl:
    # `output_stimuli` is looked up by name from this namespace by `out_eqs`.
    # NOTE(review): no group uses `input_eqs`, so `input_stimuli` is not yet fed
    # into the network. The final batch may also be shorter than batch_size.
    net.run(batch_duration)

In [None]:
def training_phase_1(network, dl):
    """Placeholder for the first training phase (not implemented yet).

    Currently it only consumes `dl`, unpacking every batch into two items, and
    returns None.
    """
    for _label_input, _curve_input in dl:
        pass  # TODO: implement phase-1 learning on `network`

## Other experiments: single-neuron and two-group sanity checks

In [None]:
# Single-neuron sanity check: drive one AdEx neuron with a constant current and
# look at its membrane potential and spike times.
eqs_neuron = '''
dv/dt = (-gL * (v - EL) + gL * DeltaT * exp((v - Vth) / DeltaT) + I - A) / Cm : volt
dA/dt = (c * (v - EL) - A) / tau_A : amp
I = stimulus(t)  : amp  
'''

neuron = make_groups([1], eqs = eqs_neuron)[0]

# Constant input current in nA. Try values between 0.6 and 1.5, the range the
# dataset curves are rescaled to. A one-sample TimedArray keeps returning its
# last (only) value, so the current is constant for the whole run.
stimulus_nA = 0.8
stimulus = TimedArray(np.array([stimulus_nA]) * nA, dt=10*ms)

neuron.v = EL
neuron.A = 0 * nA

# The group has a single neuron, index 0.
monitor = StateMonitor(neuron, ['v', 'I'], record=[0])
spike_monitor = SpikeMonitor(neuron)

# Explicit network: a magic run() would also pick up every other Brian object
# left in the global namespace by earlier cells (e.g. G, Os).
single_neuron_net = Network(neuron, monitor, spike_monitor)
single_neuron_net.run(1 * second)

fig, (ax_v, ax_spikes) = plt.subplots(2, 1, figsize=(10, 4))

# Membrane potential
ax_v.plot(monitor.t / ms, monitor.v[0] / mV, label='Neuron 0 Membrane Potential')
ax_v.set_xlabel('Time (ms)')
ax_v.set_ylabel('Membrane Potential (mV)')
ax_v.set_title('Stimulus Applied to Single Neuron')
ax_v.legend()

# Spike counts in 10 bins over the run
ax_spikes.hist(spike_monitor.t / ms, bins=10, alpha=0.7)
ax_spikes.set_xlabel('Time (ms)')
ax_spikes.set_ylabel('Spike Count')
ax_spikes.set_title('Firing Rate of Neuron 0')

fig.tight_layout()
plt.show()

# Observations recorded with this cell (input current in nA: observed value;
# the original notes do not state the unit of the second column):
# 0.6: 0
# 0.625: 0
# 0.63: 0.5
# 0.65: 1.5
# 0.7: 2.8
# 0.8: 4.5
# 1: 7.5
# 1.5: 14
# 2: 20

In [None]:
# Trial: connect two small groups through the excitatory synapse model.
# NOTE(review): G1 has no dynamics ('v : volt') and starts at EL, below Vcut,
# so it never crosses threshold and the synapses never transmit. Drive G1 if
# the goal is to observe postsynaptic responses.
G1 = NeuronGroup(5, 'v : volt', threshold='v > Vcut', reset='v = Vr', method='euler')  # Presynaptic neurons
G2 = NeuronGroup(5, eqs, threshold='v > Vcut', reset='v = Vr; A += b', method='euler')  # Postsynaptic neurons

# Initialize variables
G1.v = EL
G2.v = EL
G2.A = 0 * nA

# Each presynaptic spike sets (does not increment) the glutamate variable Y to 1 V.
S = Synapses(G1, G2, model=syn_eqs_exc,
             on_pre='Y = 1*volt', method = "euler")
S.connect(p=0.1)  # Random connections
S.w = 'rand() * w_init'  # Random initial weights

# Monitors
spike_mon_G1 = SpikeMonitor(G1)
spike_mon_G2 = SpikeMonitor(G2)
state_mon_G2 = StateMonitor(G2, ['v', 'I', 'A'], record=True)

# Explicit network so objects from earlier cells are not pulled into this run.
trial_net = Network(G1, G2, S, spike_mon_G1, spike_mon_G2, state_mon_G2)
trial_net.run(500 * ms)

# Membrane potential of one postsynaptic neuron
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(state_mon_G2.t / ms, state_mon_G2.v[0] / mV, label='Membrane potential (v)')
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Voltage (mV)')
ax.legend()

fig.tight_layout()
plt.show()