In [1]:
from brian2 import *
import torch
import numpy as np

from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

In [2]:
raw_train_seqs = torch.load('raw_train_seqs.pt')
raw_test_seqs = torch.load('raw_test_seqs.pt')

In [3]:
def rate_encode(tensor, d):

    n, features = tensor.shape
    assert features == 5, "The tensor must have 5 features per data point"
    assert d % features == 0, "d must be divisible by 5."

    max_firing_rate= d/5
    print(max_firing_rate)
    time_frame = d // features

    normalized_data = (tensor - np.min(tensor)) / (np.max(tensor) - np.min(tensor))

    spike_trains = np.zeros((n, d))
    for i in range(n):
        encoded_data_point = np.array([])
        for j in range(features):
            # count spikes
            spike_count = np.round(normalized_data[i, j] * max_firing_rate).astype(int)
            # doesn't exceed time frame
            spike_count = min(spike_count, time_frame)
            # generate spike train
            feature_train = np.zeros(time_frame)
            if spike_count > 0:
                spike_times = np.random.choice(time_frame, spike_count, replace=False)
                feature_train[spike_times] = 1
            
            encoded_data_point = np.concatenate((encoded_data_point, feature_train))
        
        spike_trains[i] = encoded_data_point

    return spike_trains

In [4]:
# convert spike trains to spike times
def convert_to_spike_times(spike_train, dt):
    spike_indices = np.where(spike_train == 1)[0]
    times = spike_indices * dt
    # set indices to 0 as each SpikeGeneratorGroup has only one neuron
    indices = np.zeros_like(spike_indices)
    return indices, times

In [5]:
def recorded_spikes_to_dataset(recorded_spikes, num_selected_neurons):
    n = num_selected_neurons
    indices = []
    for i in range(n):
        step = 50 // n
        indices.append(step * i)
    
    num_data_points = len(recorded_spikes)
    dataset = np.zeros((num_data_points, n))
    
    for i in range(num_data_points):
        spike_data = recorded_spikes[i]
        for idx in range(num_selected_neurons): # access the neurons 
            j = indices[idx]
            spikes_j = spike_data[j]
            dataset[i, idx] = len(spikes_j)
    
    return dataset

In [6]:
d = 100 
# define the length of spike train per data point
# d/5 will be the length of spike trains per input neuron

train_spikes = rate_encode(raw_train_seqs.numpy(), d)
test_spikes = rate_encode(raw_test_seqs.numpy(), d)

20.0
20.0


In [7]:
# Simulation settings
start_scope()
num_neurons = 50
duration = 1 * second

net = Network()

# Neuron model
eqs = '''
dv/dt = (I-v)/tau : 1
I : 1
tau : second
'''

# reservoir
G = NeuronGroup(num_neurons, eqs, threshold='v>1', reset='v = 0', method='exact')
G.v = 'rand()'
G.I = 'rand()'
G.tau = '2.5*ms'
net.add(G)

# STDP model
tau_pre = tau_post = 5*ms
wmax = 0.5
A_pre = 0.01
A_post = -A_pre * 1.05

stdp_eqs = '''
w : 1
dapre/dt = -apre / tau_pre : 1 (event-driven)
dapost/dt = -apost / tau_post : 1 (event-driven)
'''

# Synapses
p_connect = 0.1  # prob of connection
S = Synapses(G, G, model=stdp_eqs, on_pre='v_post += w; apre += A_pre; w = clip(w + apost, 0, wmax)',
             on_post='apost += A_post; w = clip(w + apre, 0, wmax)')


S.connect(p=p_connect)
S.w = 'rand() * wmax'  # weight init
net.add(S)  

# Monitor
M = SpikeMonitor(G)
net.add(M)

# set up input neurons
num_input_neurons = 5
input_neurons = [SpikeGeneratorGroup(1, [0], [0*ms], dt=None) for _ in range(num_input_neurons)]

# connect input neurons to the network
for inp_neuron in input_neurons:
    S = Synapses(inp_neuron, G, on_pre='v += 0.5')
    S.connect(p=0.5)
    net.add(inp_neuron)
    net.add(S)

In [8]:
segment_length = len(train_spikes[0]) // 5 
print(segment_length)
duration_of_data_point = 50*ms

recorded_spikes = []

idx = 0

for data_point in train_spikes:

    for i in range(num_input_neurons):
        start, end = i * segment_length, (i + 1) * segment_length
        segment = data_point[start:end]
        indices, times = convert_to_spike_times(segment, 5*ms)
#         print(times)
        
        input_neurons[i].set_spikes(indices, times)
    
    net.run(duration_of_data_point)
    net.t_ = 0*ms
    recorded_spikes.append(M.spike_trains())

    # clear and reinit
    net.remove(M) 
    M = SpikeMonitor(G) 
    net.add(M)
    
    G.v = '0'

    for inp_neuron in input_neurons:
        inp_neuron.set_spikes([0], [0*ms])
        
    if idx % 10 == 0:
        print(idx)
    idx += 1


20




0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
530
540
550
560
570
580
590


In [9]:
print(len(recorded_spikes))
n_out = 10
x_train_rs = recorded_spikes
x_train_snn = recorded_spikes_to_dataset(recorded_spikes, n_out)
x_train_snn = torch.from_numpy(x_train_snn)
torch.save(x_train_snn, 'x_train_snn.pt')

X_train_snn = torch.load('x_train_snn.pt')

597


In [10]:
# print to see spike counts of each neuron for each data point
idx = 50
spike_data = recorded_spikes[idx]

for neuron_index in spike_data:
    spikes = spike_data[neuron_index]
#     print(f"Neuron {neuron_index} spiked at times {spikes}")
    print(len(spikes))

115
24
18
7
6
114
7
15
13
0
16
104
109
8
106
24
28
119
108
104
8
4
110
111
111
34
5
111
25
0
7
3
114
0
118
13
127
106
7
112
107
115
6
178
0
21
4
121
0
105


In [11]:
segment_length = len(train_spikes[0]) // 5 
print(segment_length)
duration_of_data_point = 50*ms

recorded_spikes = []

idx = 0
for data_point in test_spikes:
    
    for i in range(num_input_neurons):
        start, end = i * segment_length, (i + 1) * segment_length
        segment = data_point[start:end]
        indices, times = convert_to_spike_times(segment, 5*ms)
#         print(times)
        
        input_neurons[i].set_spikes(indices, times)
    
    net.run(duration_of_data_point)
    net.t_ = 0*ms
    recorded_spikes.append(M.spike_trains())


    net.remove(M)
    M = SpikeMonitor(G)
    net.add(M)
    
    # Reset neuron states
    G.v = '0'

    for inp_neuron in input_neurons:
        inp_neuron.set_spikes([0], [0*ms])
        
    if idx % 10 == 0:
        print(idx)
    idx += 1

20
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190


In [12]:
x_test = recorded_spikes_to_dataset(recorded_spikes, n_out)
print(np.shape(x_test))
x_test_snn = torch.from_numpy(x_test).clone()
torch.save(x_test_snn, 'x_test_snn.pt')
X_test_snn = torch.load('x_test_snn.pt')

(199, 10)


In [13]:
train_labels = torch.load('train_labels.pt')
test_labels = torch.load('test_labels.pt')

In [14]:
# 
X_train = X_train_snn.numpy()
y_train = train_labels.numpy()
# Train the model
model = LinearRegression()
model.fit(X_train, y_train)

print(np.shape(X_train))

y_pred = model.predict(X_train)
mse = mean_squared_error(y_train, y_pred)
print("train MSE:", mse)

# print(y_pred[500:550])
# print(y_train[500:550])


(597, 10)
MSE: 847.1319913073451


In [21]:
y_test = test_labels.numpy()
y_pred_test = model.predict(X_test_snn)

mse = mean_squared_error(y_test, y_pred_test)
print("test MSE:", mse)
mse = mean_squared_error(y_test[0:49], y_pred_test[0:49])
print("test MSE 1:", mse)
mse = mean_squared_error(y_test[50:99], y_pred_test[50:99])
print("test MSE 2:", mse)
mse = mean_squared_error(y_test[100:149], y_pred_test[100:149])
print("test MSE 3:", mse)
mse = mean_squared_error(y_test[150:198], y_pred_test[150:198])
print("test MSE 4:", mse)

test MSE: 782.5939669647572
test MSE 1: 895.8751509341905
test MSE 2: 728.6785234649035
test MSE 3: 840.5240212540359
test MSE 4: 642.9507739917561


In [16]:
raw_test = torch.load('raw_test_seqs.pt')

In [17]:
print(raw_test[2,0])

tensor(149.8412)


In [18]:
raw_test_obs = []
for i in range(raw_test.shape[0]):
    raw_test_obs.append(raw_test[i, 0])

In [19]:
def calculate_variance(lst):
    mean = np.sum(lst) / len(lst)
    sum_of_squares = np.sum((x - mean) ** 2 for x in lst)
    var = sum_of_squares / len(lst)

    return var

In [20]:
# compute variance
v1 = calculate_variance(raw_test_obs[0:49])
v2 = calculate_variance(raw_test_obs[50:99])
v3 = calculate_variance(raw_test_obs[100:149])
v4 = calculate_variance(raw_test_obs[150:198])
print(v1)
print(v2)
print(v3)
print(v4)

  sum_of_squares = np.sum((x - mean) ** 2 for x in lst)


tensor(890.2731)
tensor(625.0175)
tensor(699.3123)
tensor(512.1569)
