In [1]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from tqdm.notebook import trange
from module import *

In [2]:
# Glyph-row vocabulary: each named row is six input bits keyed by channel 0..5.
data_rows = {
    name: dict(enumerate(bits))
    for name, bits in [
        ("e",   (0, 0, 0, 1, 1, 1)),
        ("r",   (0, 0, 1, 1, 1, 0)),
        ("c",   (0, 1, 0, 1, 0, 1)),
        ("cr",  (0, 1, 1, 1, 0, 0)),
        ("l",   (1, 0, 0, 0, 1, 1)),
        ("lr",  (1, 0, 1, 0, 1, 0)),
        ("lc",  (1, 1, 0, 0, 0, 1)),
        ("f",   (1, 1, 1, 0, 0, 0)),
        ("off", (0, 0, 0, 0, 0, 0)),
    ]
}

# Each digit 0..9 is spelled top-to-bottom as five glyph-row names.
numbers = dict(enumerate([
    "f lr lr lr f",   # 0
    "c c c c c",      # 1
    "f r f l f",      # 2
    "f r f r f",      # 3
    "lr lr f r r",    # 4
    "f l f r f",      # 5
    "cr l f lr f",    # 6
    "f r r r r",      # 7
    "f lr f lr f",    # 8
    "f lr f r lc",    # 9
]))

# Five copies of the "e" row pattern, as a (5, 6) integer array.
empty = np.tile([0, 0, 0, 1, 1, 1], (5, 1))

# digit -> list of row dicts (shared references into `data_rows`; callers copy).
numbers_rows = {
    digit: [data_rows[name] for name in glyphs.split()]
    for digit, glyphs in numbers.items()
}

from ipywidgets import SelectMultiple, Select, Button, HBox, VBox, Output, IntRangeSlider, Checkbox, Accordion
def ui(data):
    """Interactive step-plot viewer for simulation output.

    Parameters
    ----------
    data : pandas.DataFrame
        Output of ``net.feed_raw``: the index is used as the time axis and
        columns are looked up by node id.

    Returns
    -------
    ipywidgets.HBox
        Controls (time-window slider, node-type checkboxes, per-type node
        selectors, Draw button) next to an output area holding the plot.

    Notes
    -----
    Reads the global ``net`` (its ``layers`` and ``nodes`` tables) to build
    the layer selector, the node-type checkboxes and the node lists.
    """
    graph = Output()

    def build_node_selectors():
        """Return one multi-select of node ids per checked node type."""
        selectors = []
        for chbox in type_chboxs:
            if not chbox.value:
                continue
            # Boolean mask instead of a query string: no quoting pitfalls.
            node_ids = net.nodes.index[net.nodes.type == chbox.description].tolist()
            # '---' is a placeholder option that is ignored when plotting.
            options = np.insert(np.array(node_ids, dtype=object), 0, '---')
            selector = SelectMultiple(options=options, description=chbox.description)
            selector.rows = min(len(selector.options), 10)
            selectors.append(selector)
        return selectors

    def update_graph(_):
        """Redraw every selected node over the slider's time window."""
        # Slider values are used as positions; assumes a 0-based integer
        # index (as produced with dt=1) — TODO confirm.
        start, stop = xw.value
        time_axis = data.index.tolist()[start:stop]
        nodes = [int(v) for sel in select_nodes for v in sel.value if v != '---']
        with graph:
            graph.clear_output(wait=True)
            fig, ax = plt.subplots(figsize=(8, 8))
            for node in nodes:
                ax.step(time_axis, data[node].iloc[start:stop])
            ax.legend(nodes)
            plt.show()

    def refresh_node_selectors(_=None):
        """Rebuild the node selectors from the checkboxes and re-wire them."""
        # Mutate in place so update_graph (which closes over the list) sees it.
        select_nodes[:] = build_node_selectors()
        for selector in select_nodes:
            selector.observe(update_graph, names='value')
        node_box.children = tuple(select_nodes)

    layer_select = Select(
        options=net.layers.layer.unique(),
        value=-1,
        description='Слой'  # "Layer"
    )
    type_chboxs = [
        Checkbox(value=True, description=node_type, disabled=False)
        for node_type in net.nodes.type.unique()
    ]
    select_nodes = build_node_selectors()
    node_box = VBox(select_nodes)

    # Any change of layer or of the enabled node types rebuilds the selectors.
    layer_select.observe(refresh_node_selectors, names='value')
    for chbox in type_chboxs:
        chbox.observe(refresh_node_selectors, names='value')

    xw = IntRangeSlider(
        min=data.index.min(),
        max=data.index.max(),
        value=(data.index.min(), data.index.max()),
        description='Окно просмотра',  # "View window"
        disabled=False
    )
    select = Accordion(children=[VBox(type_chboxs), node_box])
    select.selected_index = 1
    select.set_title(0, 'Типы')  # "Types"
    select.set_title(1, 'Ноды')  # "Nodes"

    xw.observe(update_graph, names='value')
    for selector in select_nodes:
        selector.observe(update_graph, names='value')
    draw = Button(description='Draw')
    draw.on_click(update_graph)
    return HBox((VBox((xw, select, draw)), graph))

In [3]:
# Timing, in simulation steps: how long each glyph row is held, and how long
# the silence after each digit lasts.
row_time = 10
silent_time = 100

# Keyword arguments forwarded to `ttron_layer` when a network is built.
params_dendrites = dict(
    tau_leak=2 * row_time,
    tau_inhibitory=1,
    tau_refractory=1,
    tau_ltp=2 * row_time,
    thres=1150,
    ainc=0.7,
    adec=-3,
    wmax=255,
    wmin=1,
    learning=True,
    wta=True,
    layer_type="ttron",
)

# Four identical connection maps: each input channel 0..5 maps to [0, 1].
# NOTE(review): not referenced elsewhere in this notebook.
dendrites = [
    {"connections": {channel: [0, 1] for channel in range(6)}}
    for _dendrite in range(4)
]

# 1 Подготовка новой сети

## Генерация сети

In [4]:
def new_network():
    """Build a fresh simulator with one ttron layer.

    Returns
    -------
    tuple
        ``(net, teacher_nodes, who_is_who)``: the simulator followed by the
        two values returned by ``SpikeNetworkSim.ttron_layer``.
    """
    network = SpikeNetworkSim(inputs_l=6, dt=1)
    teachers, roles = network.ttron_layer(
        num_nodes=4,
        num_cat_inputs=4,
        delay_depth=6,
        **params_dendrites,
    )
    return network, teachers, roles

net, teacher_nodes, who_is_who = new_network()

In [5]:
net.nodes[net.nodes["type"] == "presynaptic"]

Unnamed: 0,type,listening,broadcasting,priority,layer
89,presynaptic,"[0, 1, 2, 3, 4, 5, 36, 37, 38, 39, 40, 41]",[92],2,0
96,presynaptic,"[0, 1, 2, 3, 4, 5, 36, 37, 38, 39, 40, 41]",[99],2,0
103,presynaptic,"[0, 1, 2, 3, 4, 5, 36, 37, 38, 39, 40, 41]",[106],2,0
110,presynaptic,"[0, 1, 2, 3, 4, 5, 36, 37, 38, 39, 40, 41]",[113],2,0


# 0 Инициализация наборов данных

## Обучающий датасет

In [6]:

# Teacher schedule per digit: {time offset within the digit: teacher index}.
# The teacher index selects an entry of `teacher_nodes`.
# Schedules for digits 9 and 0 are currently disabled:
#   9: {25: 0, 45: 1}
#   0: {25: 1, 45: 1}
genome = {
    1: {25: 1, 45: 1},
    2: {25: 3, 45: 2},
    3: {25: 3, 45: 3},
    4: {25: 0, 45: 3},
    5: {25: 2, 45: 3},
    6: {25: 2, 45: 0},
    7: {25: 3, 45: 1},
    8: {25: 0, 45: 0},
}

# Teacher node id -> how many (digit, offset) slots of `genome` target it.
pattern_sights = {
    teacher_nodes[k]: (np.array([[g[25], g[45]] for g in genome.values()]) == k).sum()
    for k in range(4)
}


In [7]:
def gen_test(n_num, rng=None):
    """Generate a random training sequence of `n_num` digits with teacher signals.

    Steps:

    1. Draw `n_num` digits uniformly from 1..8. Each digit contributes its rows
       from `numbers_rows`, followed by one "off" row labelled -1.
    2. Hold every digit row for `row_time` steps and the silence row for
       `silent_time` steps, padding with "off" rows.
    3. Write the teacher inputs (`teacher_nodes`) into every row. All of them
       are 0, except that at the digit offsets listed in `genome[label]` the
       scheduled teacher node is set to 1. The teacher value of row i is driven
       by the label of row i + 1 (one step of look-ahead).

    Parameters
    ----------
    n_num : int
        Number of digits to draw. A value <= 0 yields an empty list.
    rng : numpy.random.Generator, optional
        Source of randomness for reproducible sequences. When omitted the
        global ``np.random`` state is used (original behaviour).

    Returns
    -------
    list of dict
        One input dict per time step. Every row carries every teacher key.
    """
    digits = [
        int(rng.integers(1, 9)) if rng is not None else np.random.randint(1, 9)
        for _ in range(n_num)
    ]

    # 1. Glyph rows per digit, each digit followed by a silence row.
    rows, labels = [], []
    for digit in digits:
        for row in numbers_rows[digit]:
            rows.append(row.copy())
            labels.append(digit)
        rows.append(data_rows["off"].copy())
        labels.append(-1)

    # 2. Stretch every row over time.
    seq_data, seq_labels = [], []
    for row, label in zip(rows, labels):
        hold = silent_time if label == -1 else row_time
        seq_data.append(row.copy())
        seq_labels.append(label)
        for _ in range(hold - 1):
            seq_data.append(data_rows["off"].copy())
            seq_labels.append(label)

    if not seq_data:
        return seq_data

    # 3. Teacher signals; `step` is the offset within the current digit.
    silent_teacher = {node: 0 for node in teacher_nodes}
    prev_label = seq_labels[0]
    step = 1
    for i, label in enumerate(seq_labels[1:]):
        seq_data[i].update(silent_teacher)
        if prev_label == -1 and label != prev_label:
            step = 0
        schedule = genome.get(label)
        if schedule is not None and step in schedule:
            seq_data[i][teacher_nodes[schedule[step]]] = 1
        prev_label = label
        step += 1
    # The loop never visits the last row; give it the same (silent) teacher
    # keys so every row has an identical key set.
    seq_data[-1].update(silent_teacher)
    return seq_data

## Тестовый датасет

In [9]:
def _build_test_dataset(digits=range(10)):
    """Build the deterministic evaluation sequence: each digit in `digits` once.

    Steps:

    1. Each digit contributes its rows from `numbers_rows`, followed by one
       "off" row labelled -1.
    2. Digit rows are held for `row_time` steps and the silence row for
       `silent_time` steps, padding with "off" rows.
    3. Teacher inputs are written into every row. For a digit in `genome`:
       - the teacher scheduled at offset 25 is on for offsets 0-14;
       - both teachers are on for offsets 15-24;
       - the teacher scheduled at offset 45 is on from offset 25 onward.
       All other teacher inputs are 0. As in `gen_test`, row i is driven by
       the label of row i + 1.

    Returns
    -------
    (list of dict, list of int)
        Per-time-step input dicts and their labels (-1 for silence).
    """
    rows, labels = [], []
    for digit in digits:
        for row in numbers_rows[digit]:
            rows.append(row.copy())
            labels.append(digit)
        rows.append(data_rows["off"].copy())
        labels.append(-1)

    data, data_labels = [], []
    for row, label in zip(rows, labels):
        hold = silent_time if label == -1 else row_time
        data.append(row.copy())
        data_labels.append(label)
        for _ in range(hold - 1):
            data.append(data_rows["off"].copy())
            data_labels.append(label)

    if not data:
        return data, data_labels

    silent_teacher = {node: 0 for node in teacher_nodes}
    # Start like gen_test (offset 1 at i == 0) so the first digit's offsets
    # line up with those of every later digit.
    prev_label = data_labels[0]
    step = 1
    for i, label in enumerate(data_labels[1:]):
        data[i].update(silent_teacher)
        if prev_label == -1 and label != prev_label:
            step = 0
        if label in genome:
            schedule = genome[label]
            if step < 25:
                data[i][teacher_nodes[schedule[25]]] = 1
            if step >= 15:
                data[i][teacher_nodes[schedule[45]]] = 1
        prev_label = label
        step += 1
    # The loop never visits the last row; give it the silent teacher keys too.
    data[-1].update(silent_teacher)
    return data, data_labels


test_data, test_labels = _build_test_dataset()

# 2 Обучение

## Цикл обучения

In [None]:
# Number of independently initialised networks trained by the loop below
# (0 skips training entirely).
ZOO_POP: int = 0

In [None]:
def _greedy_match_errors(error):
    """Greedily pair teachers with outputs and return the matched errors.

    Assumes `error` maps each teacher key to a dict ``{output: error}`` (the
    shape `net.error` appears to return — TODO confirm). Each round does:

    - pick the remaining (teacher, output) pair with the smallest
      ``abs(error)``;
    - record its signed error;
    - drop that teacher, and drop that output from every remaining teacher.

    Returns
    -------
    numpy.ndarray
        Signed errors of the matched pairs, in matching order.
    """
    remaining = {teacher: row.copy() for teacher, row in error.items()}
    matched = []
    while remaining:
        best_output = {
            t: min(row, key=lambda out: abs(row[out]))
            for t, row in remaining.items()
        }
        teacher = min(best_output, key=lambda t: abs(remaining[t][best_output[t]]))
        output = best_output[teacher]
        matched.append(remaining.pop(teacher)[output])
        for row in remaining.values():
            del row[output]
    return np.array(matched)


N_SECTIONS = 15      # train/evaluate rounds per network
N_TRAIN_DIGITS = 10  # random digits per training round

zoo = []
weights_variance = []
net, _, _ = new_network()
start_weights = net.weights.copy()
for j in trange(ZOO_POP):
    net, _, _ = new_network()
    for section in range(N_SECTIONS):
        # Train on a fresh random sequence ...
        net.layer_params["learning"][-1] = True
        net.feed_raw(gen_test(N_TRAIN_DIGITS))
        # ... then evaluate on the fixed test sequence with learning frozen.
        net.layer_params["learning"][-1] = False
        out_test = net.feed_raw(test_data)
        error = net.error(teacher_nodes, pattern_sights=pattern_sights)
        # NOTE(review): this averages *signed* errors, so positive and negative
        # errors can cancel, and the later sort/idxmin on error_agg favours
        # negative values. Confirm whether mean(abs(...)) was intended.
        error_agg = _greedy_match_errors(error).mean()
        weights_variance.append({
            "sp": j,
            "sec": section,
            "error": error.copy(),
            # NOTE(review): .copy() is shallow; if learning updates the weight
            # arrays in place, earlier snapshots would change too — confirm.
            "weights": net.weights.weights.copy(),
            "error_agg": error_agg,
        })
d2 = pd.DataFrame(weights_variance)

In [None]:
# Learning curves: aggregated error after each training section, one line per run.
d2 = pd.DataFrame(weights_variance)
if not d2.empty:
    fig, ax = plt.subplots()
    for run, history in d2.groupby("sp"):
        ax.plot(history.sec, history.error_agg, label=f"run {run}")
    ax.set(title="Aggregated error per training section", xlabel="section", ylabel="error_agg")
    ax.legend()


In [None]:
# Best (lowest error_agg) section of each run, the 10 best runs.
# Guarded: with ZOO_POP == 0 the results frame has no columns at all.
best_sections = (
    d2.sort_values("error_agg").groupby("sp").head(1).head(10)
    if not d2.empty
    else d2
)
best_sections

In [None]:
# Evaluate the current `net` on the fixed test sequence. Learning is switched
# off first so the evaluation pass does not modify the weights.
net.layer_params["learning"][-1] = False
out_test = net.feed_raw(test_data)
# Use the `pattern_sights` computed from `genome` rather than a hard-coded
# dict whose node ids and counts can go stale.
net.error(teacher_nodes, pattern_sights=pattern_sights)

In [None]:
# Re-evaluate a fresh network loaded with `best_weights` (learning off).
# NOTE(review): `best_weights` is defined in the weight-editor cell further
# down; that cell must have been run first (hidden-state dependency).
net2, _, _ = new_network()
net2.weights.weights = best_weights.copy()
net2.layer_params["learning"][-1] = False
out_test = net2.feed_raw(test_data)
net2.error(teacher_nodes, pattern_sights=pattern_sights)

In [None]:
ui(data=out_test)

In [11]:
import ipywidgets as widgets

# Load the saved starting weights: 12 weights per presynaptic node.
array = np.array  # lets the pasted array reprs below evaluate as-is
save = {89: array([209, 133, 158,   3, 177, 177,  57, 226,   1,  82,  81,   1]),
 96: array([183,  18,   1, 205, 133, 129, 205, 207, 146, 189,  38, 151]),
 103: array([ 94,   1,  31, 194, 219, 143,  60, 219,  41, 120, 207, 207]),
 110: array([205, 126,  93,  11,  11, 183, 158,   7, 182,  19, 220,  43])}
best_weights = save
# Alternative: take the best snapshot from the training loop instead.
# best_weights = d2.loc[d2["error_agg"].idxmin()].weights
weights = best_weights  # weights most recently applied; read by the Cf callbacks

WEIGHT_NODES = list(save)           # presynaptic node ids being edited
OUTPUT_NODES = [92, 99, 106, 113]   # their broadcasting targets (plotted)
N_WEIGHTS = 12                      # weights per presynaptic node


def _weight_box(value):
    """Integer editor for one weight, bounded to [1, 255]."""
    return widgets.BoundedIntText(
        value=value,
        min=1,
        max=255,
        step=1,
        description='',
        layout=widgets.Layout(width='50px', height='30px')
    )


def _scale_box():
    """Float editor for a row's rescaling coefficient (Cf)."""
    return widgets.BoundedFloatText(value=1, min=0, max=10, step=0.1, description='Cf')


# inputs[i]: the N_WEIGHTS weight boxes of WEIGHT_NODES[i], then its Cf box.
inputs = [
    [_weight_box(best_weights[node][nc]) for nc in range(N_WEIGHTS)] + [_scale_box()]
    for node in WEIGHT_NODES
]


def rescale(nr):
    """Return a callback that rescales row `nr` from the last applied weights.

    Each weight box of row `nr` is set to ``int(Cf * weights[node][nc])``.
    ``weights`` is the global dict of the most recently applied weights, so
    scaling never compounds. Values outside [1, 255] are clamped by the widget.
    """
    node = WEIGHT_NODES[nr]

    def inner(_):
        scale = inputs[nr][N_WEIGHTS].value
        for nc in range(N_WEIGHTS):
            inputs[nr][nc].value = int(scale * weights[node][nc])
    return inner


scalers = [rescale(i) for i in range(len(WEIGHT_NODES))]
for row, scaler in zip(inputs, scalers):
    # React only to edits of the value, not to every trait change.
    row[N_WEIGHTS].observe(scaler, names="value")


def get_weights():
    """Read the editor's weights, store them in the global `weights`, and return them."""
    global weights
    weights = {
        node: np.array([box.value for box in row[:N_WEIGHTS]])
        for node, row in zip(WEIGHT_NODES, inputs)
    }
    return weights


def run_new_weights(_):
    """Simulate `test_data` with the edited weights and plot the outputs.

    Steps:

    1. Build a fresh network and load the editor weights with learning off.
    2. Feed `test_data` and step-plot OUTPUT_NODES into `graph`.
    3. Reset every Cf box to 1. When a box was not already 1, its rescale
       callback rewrites the row from the weights just applied.
    """
    sim, _, _ = new_network()
    sim.weights.weights = get_weights()
    sim.layer_params["learning"][-1] = False

    out = sim.feed_raw(test_data)
    time_axis = out.index.tolist()
    with graph:
        graph.clear_output(wait=True)
        fig, ax = plt.subplots(figsize=(8, 8))
        for node in OUTPUT_NODES:
            ax.step(time_axis, out[node])
        ax.legend([1, 2, 3, 4])
        plt.show()
    for row in inputs:
        row[N_WEIGHTS].value = 1


# On button press, re-run the simulation with the edited weights.
draw = Button(description="Draw")
draw.on_click(run_new_weights)
graph = Output()
run_new_weights(None)
VBox([HBox(row) for row in inputs] + [draw, graph])










VBox(children=(HBox(children=(BoundedIntText(value=209, layout=Layout(height='30px', width='50px'), max=255, m…

In [13]:
# Snapshot of the weights currently entered in the editor. This call also
# overwrites the global `weights` that the Cf rescaling callbacks read.
get_weights()

{89: array([209, 133,  85, 174, 171, 175,  57, 226,  58,  82,  81,   1]),
 96: array([183,  18,   1, 205, 133, 129, 205, 207, 146, 189,  38, 151]),
 103: array([ 94,   1,  31, 194, 219, 143,  14, 184,   9, 120, 207, 207]),
 110: array([ 33, 126, 253,  11,  11, 183, 158,   7, 182,  19, 220,  43])}