In [1]:
import pandas as pd
import numpy as np

In [2]:
DATA_PATH = './'

In [3]:
iris_df = pd.read_csv(DATA_PATH + 'data/iris_encoded.csv', index_col=0)

In [4]:
list_weight = np.loadtxt(DATA_PATH + 'data/weights.csv', delimiter = ',')

In [5]:
iris_data = iris_df.iloc[:-1, :]
iris_labels = iris_df.iloc[-1, :].astype(int)

In [6]:
iris_data.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,140,141,142,143,144,145,146,147,148,149
0,,,,,,,,,0.393469,,...,,,,,,,,,,
1,,0.864665,0.0001,0.393469,,,0.393469,,,0.864665,...,,,,,,,,,,
2,0.0001,0.864665,,,0.393469,,,0.393469,,0.864665,...,,,,,,,,,,
3,,,,,,0.393469,,,,,...,,,,,,,,,,
4,,,,,,,,,,,...,,,0.393469,,,,,,,0.0001


In [7]:
def model_data(lat_ne):
    
    train_stack = np.where(lat_ne > 0, lat_ne, 0)
    
    return train_stack      # (150, 40)

In [8]:
def leaky_integrate_and_fire(mem, x, w, beta, threshold=1):
    spk = (mem > threshold) # if membrane exceeds threshold, spk=1, else, 0
    # mem = beta * mem + w*x - spk*threshold
    if spk:
        mem = 0
    else:
        mem = beta * mem + w*x
    # mem = beta * mem + w*x
    return spk, mem

In [9]:
def LIF_SNN(n, data, weight, v_spike = 0.25):
    l = len(data)


    V_min = 0
    r = 5
    tau = 2.5
    dt = 0.01
    beta = 1 - dt / tau
    t_max = 10
    time_stamps = t_max / dt
    # time_relax = 10
    v = np.zeros((n, l, int(time_stamps)))
    t_post = np.zeros((n, l))
    t_post_ = np.zeros((n, int(l / 3)))
    v[:, :, 0] = V_min
    
    # data_spike_list = np.zeros((l, n, len(data[0]), int(time_stamps)))
    data_spike_list_sum = np.zeros((l, n, int(time_stamps)))

    for u in range(l):          # data point
        for ni in range(n):     # 3 classes
            f0 = (np.round(data[u][np.newaxis].T, 3) * 1000).astype(int)      # (40, 1)
            f1 = np.tile(np.arange(1000), (40, 1))                            # (40, 1)
            f2 = np.where(((f1 == f0) & (f0 > 0)), 1, 0)                      # (40, 1000)        1 wherever the spike is
            f2 = f2 * weight[ni][np.newaxis].T
            # data_spike_list[u][ni] = f2
            data_spike_list_sum[u][ni] = np.sum(f2, axis = 0)

    for u in range(l):          # data point
        for step in range(int(time_stamps) - 1):
            for ni in range(n):     # 3 classes
                spike_list = data_spike_list_sum[u][ni]
                (spk_out, mem_out) = leaky_integrate_and_fire(v[ni, u, step], r * dt/tau * spike_list[step], 1, beta, threshold=v_spike)

                if spk_out:
                    t_post[ni, u] = step

                v[ni, u, step + 1] = mem_out

    return v, t_post

In [10]:
def accuracy_snn(spike_time, iris_labels):
    
    target_type = iris_labels
    
    spike_time_ = np.where(spike_time > 0, np.array(([1], [2], [3])), np.nan)
    final_test = np.full([len(spike_time[0])], np.nan).astype(int)
    for i in range(len(spike_time[0])):
        try:
            final_test[i] = spike_time_[:, i][spike_time[:, i] == np.min(spike_time[:, i][spike_time[:, i] > 0])][0]
        except:
            final_test[i] = 0
    
    ac = np.sum(np.where(final_test == target_type, 1, 0)) / len(target_type)

    print('accur.:', np.round(ac * 100, 2), '%')

    return final_test, target_type

In [11]:
lat_ne = np.transpose(iris_data.values)       # (150, 40)

In [12]:
test_stack = model_data(lat_ne)

res = LIF_SNN(3, test_stack, list_weight)
spike_time = res[1]

out = accuracy_snn(spike_time, iris_labels.values)

accur.: 94.0 %


  final_test = np.full([len(spike_time[0])], np.nan).astype(int)
