## Random pattern classification with surrogate gradient learning

Notebook adapted for Julia from [Friedemann Zenke](https://github.com/fzenke/spytorch/tree/main/notebooks). Tested on Julia v1.7.2. <div style="text-align: right"> &copy; Hartmut Fitz (2024)</div>

### Run once before the whole experiment!

In [None]:
# Results collected across noise probabilities (one entry appended per probability).
diff_accuracies_probabilities = Float64[]        # mean accuracy on the clean words
diff_smoothed_loss = Vector{Float64}[]           # smoothed average loss curve
diff_accuracies_probabilities_pen = Float64[]    # mean accuracy on the noisy words

In [None]:
# Flip probability handed to add_noice for the noisy ('accented') words.
# Values used across runs: 0.2, 0.4, 0.6.
prob_change = 0.6

### Re-run before each noise probability!

In [None]:
using Distributions
using Printf
using Flux
using Flux.Optimise: update!
using Flux.Zygote: pullback, gradient, @adjoint, @ignore, Buffer
using Random
using Plots
using Pkg
# Pkg.add("Plots")  # Install if not already done
gr(dpi = 300) # set resolution for plots with gr backend
# Per-probability result stores, reset whenever this cell runs (one entry per run).
loss_values_more = Vector{Float64}[]   # loss curve (one value per epoch) of each run
accuracies_normal_words = Float64[]    # accuracy on the clean words of each run
accuracies_pen = Float64[]             # accuracy on the noisy words of each run

In [None]:
using Pkg
# Install Interpolations only when it is missing, instead of running Pkg.add (and
# resolving the environment) every time this cell is executed.
# NOTE(review): nothing visible in this notebook uses Interpolations — confirm it is needed.
Base.find_package("Interpolations") === nothing && Pkg.add("Interpolations")

In [None]:
# Fix the global RNG so that spike alphabets, weights and noise are reproducible.
# Seeds used for the different runs: 2, 80, 123, 150, 235.
Random.seed!(235)

In [None]:
# Import the module
# Loads the local IPA conversion module and prints the transcription of every word in
# `words` as a comma-separated list of quoted strings (presumably to paste into the
# `dictionary` further down).
# NOTE(review): hard-coded working directory; adjust for your machine.
cd("C:/ru-2023-2024/thesis")
include("convert_to_ipa.jl")
using .convert_to_ipa

# Example usage
# NOTE(review): `word_list` is not defined above this cell; presumably it comes from the
# convert_to_ipa module or from a later cell run earlier — confirm.
words = word_list
phon=[]
for w in words
    p= convert_to_ipa.Convert_to_ipa(w)
    # NOTE(review): if Convert_to_ipa returns a String, append! adds its individual
    # characters rather than the whole transcription (push! would add the string) — confirm.
    append!(phon,p)
end
# Print each entry of phon as "entry", (quoted, comma-separated, no newline).
for p in phon
    print("\"",p,"\"",",")
end

### Parameters

In [None]:
# Simulation parameters (time in seconds)
dt = 0.001      # integration time step
time = 0.1      # each pattern is "on" for 100 ms; not read elsewhere in this notebook.
                # NOTE: assigning this global shadows Base.time in Main.

# Model parameters
tau_mem = 10e-3               # membrane time constant
beta = exp(-dt / tau_mem)     # per-step membrane decay factor
num_Hidden = 30               # hidden-layer size

letter_Nsteps = 25            # time steps per phoneme in an encoded sequence
num_Inputs = 30               # input neurons (columns of a spike raster)

### Preprocess input (functions)

In [None]:

"""
    encode_sequences(sequences, spike_dict, letter_Nsteps, num_Inputs, max_length)

Encode each character sequence as a spike raster of size
`(letter_Nsteps * max_length, num_Inputs)`.

Character `i` of a sequence fills rows `(i - 1) * letter_Nsteps + 1 : i * letter_Nsteps`
with `spike_dict[char]`. Rows after the last character stay zero, so shorter sequences
are padded with silence up to `max_length` characters.

Returns a `Vector{Matrix{Int}}` with one raster per sequence. A character missing from
`spike_dict` raises a `KeyError`. A sequence longer than `max_length` characters
raises a `BoundsError`.
"""
function encode_sequences(sequences, spike_dict, letter_Nsteps, num_Inputs, max_length)
    enc_sequences = Matrix{Int}[]
    for s in sequences
        sequence_spikes = zeros(Int, letter_Nsteps * max_length, num_Inputs)
        # enumerate counts characters (not bytes), so multi-byte IPA symbols index correctly
        for (i, char) in enumerate(s)
            rows = ((i - 1) * letter_Nsteps + 1):(i * letter_Nsteps)
            sequence_spikes[rows, :] .= spike_dict[char]
        end
        push!(enc_sequences, sequence_spikes)
    end
    return enc_sequences
end










"""
    make_spike_alphabet(alphabet, Nsteps, num_Inputs; rate = 0.1)

Draw one random binary spike raster of size `(Nsteps, num_Inputs)` per element of
`alphabet` (in iteration order). Each entry is independently 1 with probability `rate`.

Returns a `Vector{Matrix{Int}}` with `length(alphabet)` rasters. It draws from the global
RNG, so the result depends on `Random.seed!`.
"""
function make_spike_alphabet(alphabet, Nsteps, num_Inputs; rate = 0.1)
    spike_alphabet = Vector{Matrix{Int}}(undef, length(alphabet))
    for i in eachindex(spike_alphabet)
        # same single rand(Nsteps, num_Inputs) draw per letter as before, so seeded runs match
        spike_alphabet[i] = Int.(rand(Nsteps, num_Inputs) .< rate)
    end
    return spike_alphabet
end

"""
    one_hot_encoding(phoneme_sequences, word_list, mapping)

Return an `Int8` matrix with one row per sequence and one column per word. Entry
`(s, w)` is 1 exactly when `mapping[phoneme_sequences[s]] == word_list[w]`.
Sequences whose mapped word is not in `word_list` get an all-zero row. A sequence
missing from `mapping` raises a `KeyError`.
"""
function one_hot_encoding(phoneme_sequences, word_list, mapping)
    labels = zeros(Int8, length(phoneme_sequences), length(word_list))
    for (row, seq) in enumerate(phoneme_sequences), (col, candidate) in enumerate(word_list)
        if mapping[seq] == candidate
            labels[row, col] = 1
        end
    end
    return labels
end


"""
    flip_spikes_with_probability(spike_pattern, probability)

Add noise to `spike_pattern` in place and return it. Every spike (entry == 1) is
removed with probability `probability`, and every other entry is toggled
(`x -> 1 - x`) with probability `probability / 10`.

Elements are visited row by row, with one `rand()` draw each, so a given seed always
yields the same noise.
"""
function flip_spikes_with_probability(spike_pattern, probability)
    for i in axes(spike_pattern, 1), j in axes(spike_pattern, 2)
        threshold = spike_pattern[i, j] == 1 ? probability : probability / 10
        if rand() < threshold
            spike_pattern[i, j] = 1 - spike_pattern[i, j]
        end
    end
    return spike_pattern
end

"""
    add_noice(start_index, end_index, enc_phonemes_word, probability)

Return noisy copies of the spike rasters in `enc_phonemes_word`. In rows
`start_index:end_index` of each raster, noise is added with
`flip_spikes_with_probability(…, probability)`. All other rows are copied unchanged.

The input rasters are not modified. The result is a `Vector{Matrix{Float64}}`, one per
input raster, in the same order.
"""
function add_noice(start_index, end_index, enc_phonemes_word, probability)
    changed_words = Vector{Matrix{Float64}}()
    for word_pattern in enc_phonemes_word
        # work on a copy so the caller's clean encodings are not overwritten
        noisy = copy(word_pattern)
        noisy[start_index:end_index, :] =
            flip_spikes_with_probability(noisy[start_index:end_index, :], probability)
        push!(changed_words, noisy)
    end
    return changed_words
end
                




### Make the spike alphabet

In [None]:
# Words whose phoneme symbols make up the spike alphabet.
dictionary = [
        "pɛn","æpəl","bəˈlun","kæt","dɔg","ɛləfənt","flaʊər","gɑrdən","haʊs","maʊntən","tɛn","mɛn","hɛn","dɛn","lɛn","pɛn","jɛn","fɛn"
        # ,"aɪs","ʤəŋgəl","kaɪt","lɛmən","maʊntən",
        # "naɪt","oʊʃən","pɛŋgwən","kwin","reɪnˌboʊ","sən","tri","əmˈbrɛlə","vaɪəˈlɪn","wɔtər","zaɪləˌfoʊn","jɑt","zibrə",
        # "ænt","bərd","klaʊd","dɪˈzərt","dɛzərt","ərθ","fɔrəst","fɔrɪst","greɪp","hɪl","aɪlənd","ʤækət","ʤækɪt",
        # "ˌkæŋgərˈu","laɪən","məŋki","nɛst","ɔrənʤ","ɔrɪnʤ","pəmkɪn","pəmpkɪn","kwɪlt","rɪvər","sneɪk","taɪgər",
        # "junɪˌkɔrn","veɪs","vɑz","hweɪl","weɪl","joʊgərt",  "dɛn",
# "dɛn","hɛn","mɛn","tɛn","jɛn","kɛn","lɛn","fɛn","pæn","pɪn","pən","pi","peɪ","pɛp","pər","pɛt"
]

# Every distinct symbol (Char) occurring in the dictionary, in Set iteration order.
# Symbols are inserted in the same order as before, so the Set layout — and therefore the
# symbol-to-raster assignment for a given seed — is unchanged.
alphabet = let symbols = Set{Char}()
    for word in dictionary, char in word
        push!(symbols, char)
    end
    collect(symbols)
end

# One random spike raster per symbol, looked up by symbol.
spike_alphabet = make_spike_alphabet(alphabet, letter_Nsteps, num_Inputs)
spike_dict = Dict(zip(alphabet, spike_alphabet))

### network parameters

### Manually set the training input words

In [None]:
# Training inputs (IPA transcriptions) and the output vocabulary. The order of
# `word_list` fixes which readout neuron stands for which word.
phoneme_sequences = ["tɛn", "mɛn", "hɛn", "dɛn", "lɛn", "pɛn", "jɛn", "bɛn"]
word_list = ["ben", "pen", "ten", "men", "hen", "den", "len", "yen"]

for v in (phoneme_sequences, word_list)
    print(size(v))
end

In [None]:


# IPA transcription -> written word. garden/house/mountain have no column in word_list,
# so one_hot_encoding gives them an all-zero target row.
mapping = Dict{String, String}(
    "tɛn" => "ten", "mɛn" => "men", "hɛn" => "hen", "dɛn" => "den",
    "lɛn" => "len", "pɛn" => "pen", "jɛn" => "yen", "bɛn" => "ben",
    "gɑrdən" => "garden", "haʊs" => "house", "maʊntən" => "mountain",
)

# one readout neuron per vocabulary word
num_Outputs = length(word_list)











### Process input and generate labels

In [None]:
# Build the training set: the clean training words plus noisy ('accented') copies of
# "pɛn" and "bɛn", stacked into x[time, input, sample] with one-hot targets.
# Words that are added to the training set with noise on one phoneme.
pen_list = [
"pɛn","pɛn","pɛn","pɛn","bɛn","bɛn","bɛn","bɛn"
]
# Longest sequence (in phonemes) over both lists; every raster is padded to this length.
max_length_s = maximum(length.(phoneme_sequences))
println(max_length_s)
max_length_pen = maximum(length.(pen_list))
println(max_length_pen)
max_length = max(max_length_pen,max_length_s)
# max_length =  max_length_s
println(max_length)


println(size(phoneme_sequences))

enc_phonemes = encode_sequences(phoneme_sequences, spike_dict, letter_Nsteps, num_Inputs, max_length)
enc_phonemes_pen = encode_sequences(pen_list, spike_dict, letter_Nsteps, num_Inputs, max_length)

println("type:", size(enc_phonemes[1]))
println("type:", size(enc_phonemes_pen[1]))
# Which phoneme receives the noise (1 = first phoneme); start_index:end_index are its rows.
index_letter_to_change = 1

start_index = (index_letter_to_change - 1) * letter_Nsteps + 1
end_index = index_letter_to_change * letter_Nsteps





# Noisy rasters for pen_list, drawn from the global RNG with flip probability prob_change.
accent_encoded_pen = add_noice(start_index, end_index, enc_phonemes_pen, prob_change)


# NOTE(review): these append! calls mutate the globals, so re-running this cell without
# first re-running the cell that defines phoneme_sequences adds pen_list a second time.
append!(enc_phonemes,accent_encoded_pen)
append!(phoneme_sequences,pen_list)

trainSize = length(phoneme_sequences)
println(trainSize)
target=one_hot_encoding(phoneme_sequences,word_list,mapping)
println(target)

# Stack the rasters into an Int8 array x[time, input, sample]. Nsteps (a global read by
# snn) is set to the raster length.
rows, cols = size(enc_phonemes[1])
Nsteps = rows
x = Array{Int8, 3}(undef, rows, cols, trainSize)
for i in 1:trainSize
    x[:, :, i]  .= enc_phonemes[i]
end

In [None]:
phoneme_sequences |> size |> print   # training-set size after adding the noisy words

### Plot input spike patterns

In [None]:

# Plot first input spike pattern, sanity check
"""
    plotInputPattern(x)

Raster plot of the spike matrix `x` (time steps × input neurons): a dot at height `j`
for every time step at which input neuron `j` spikes.

Non-spike entries are moved to `-5 * j`, below the visible y-range, so only spikes show.
The number of neurons is taken from `size(x, 2)`, so rasters of any width can be plotted.
"""
function plotInputPattern(x)
    dat = copy(x)
    dat[dat .< 1] .= -5
    n_inputs = size(dat, 2)
    p1 = scatter(dat[:, 1], color = "black", markersize = 4, alpha = 0.5,
        ylims = (-1, n_inputs + 1), legend = false,
        xlabel = "Time(ms)", ylabel = "Input neuron")
    for j in 2:n_inputs
        scatter!(j * dat[:, j], color = "black", markersize = 4, alpha = 0.5)
    end
    return p1
end

# for i in 1:trainSize
#     display(plotInputPattern(x[:, :, i]))
# end
# poriginalben=plotInputPattern(x[:,:,2])
# poriginal=plotInputPattern(x[:,:,1])

# p1 = plotInputPattern(x[:,:,12])
# p2 = plotInputPattern(x[:,:,13])
# p3 = plotInputPattern(x[:,:,14])
# p4 = plotInputPattern(x[:,:,15])

# p2 = plotInputPattern(x[:,:,2])
# p3= plotInputPattern(x[:,:,3])
# display(poriginal)
# display(poriginalben)
# display(p1)
# display(p2)
# display(p3)
# display(p4)



### Create network

In [None]:
## Create two sets of weights: input -> hidden (`a`) and hidden -> output (`b`)
# 7 * (1 - beta) gives some hidden spikes to begin with; magic number 7 comes from Zenke.
weight_scale = 7 * (1.0 - beta)
# Normal(μ, σ) takes a standard deviation. As in Zenke's SpyTorch tutorial, each layer is
# scaled by 1/sqrt(fan-in): num_Inputs for `a`, num_Hidden for `b`. The draws are the
# same as before whenever num_Inputs == num_Hidden.
d = Normal(0, weight_scale / sqrt(num_Inputs))
d_out = Normal(0, weight_scale / sqrt(num_Hidden))

a = Float32.(rand(d, num_Hidden, num_Inputs))       # input --> hidden weights
b = Float32.(rand(d_out, num_Outputs, num_Hidden))  # hidden --> output weights
keep_a = copy(a)  # untrained copy; to check that SG is working we compare input weights only

### Spike function and surrogate derivative

In [None]:
# the spike function
"""
    spike(x)

Heaviside step used as the spiking nonlinearity. Returns an `Int` array with 1 where
`x > 0` and 0 elsewhere. It builds a new array rather than mutating `x`, so Zygote can
trace it. Its gradient is supplied by the custom adjoint below it in this notebook.
"""
function spike(x::AbstractArray{<:Real})
    return Int.(x .> 0)
end

# the derivative of the surrogate
"""
    SGgrad(x)

Surrogate derivative of the spike function, elementwise `1 / (scale * |x| + 1)^2` with
`scale = 10`.

This is the exact derivative of the fast sigmoid `x / (1 + scale * |x|)` (for every `x`,
not only `x < 0`), as used in Zenke's SpyTorch. It peaks at 1 for `x = 0`, and `scale`
sets how quickly it decays away from the threshold.
"""
function SGgrad(x::AbstractArray{<:Real})
    scale = 10  # steepness of the surrogate (Zenke's value)
    return @. inv((scale * abs(x) + 1.0)^2)
end

# set up the custom adjoint; this tells Zygote that the spike function has derivative SGgrad
Flux.Zygote.@adjoint function spike(x)
    y = spike(x)
    # The step's true derivative is zero almost everywhere; pass the upstream gradient
    # through scaled by the surrogate derivative instead.
    spike_pullback(Δ) = (Δ .* SGgrad(x),)
    return y, spike_pullback
end

### The model

In [None]:
# Simulate the two-layer spiking network on one input raster.
#   x   : spike raster; row k (x[k, :]) is the input at time step k, for k in 1:Nsteps
#   a   : input -> hidden weights, size (num_Hidden, size(x, 2))
#   b   : hidden -> output weights, size (num_Outputs, num_Hidden)
#   rec : if true, also return per-step recordings (each Nsteps × neurons) of hidden
#         membrane potentials, hidden spikes and readout potentials
# Reads the globals Nsteps, num_Hidden, num_Outputs and dt.
# Returns max_cl ./ Nsteps: the readout potential summed over all steps and divided by
# Nsteps, i.e. a time average (not a maximum).
# NOTE(review): `@time` here only times evaluating the definition, not calls to snn.
@time function snn(x::Matrix{Int8}, a::Matrix{Float32}, b::Matrix{Float32}; rec = false)
    
    if rec
        # Recording buffers: hidden membrane potentials, hidden spikes, readout potentials.
        # spk_rec starts uninitialized (similar); every row 1:Nsteps is written in the loop.
        mem_rec = zeros(Nsteps, num_Hidden)
        spk_rec = similar(mem_rec)
        out_rec = zeros(Nsteps, num_Outputs)
    end

    # membrane and synaptic time constants
    tau_mem = 10e-3
    tau_syn = 5e-3
    syn_decay = Float32(exp(-dt/tau_syn))
    mem_decay = Float32(exp(-dt/tau_mem))

    # NOTE(review): unconditional re-allocation; duplicates the one in the `rec` branch.
    out_rec = zeros(Nsteps, num_Outputs)
    # accumulator: readout potential summed over time steps (one entry per output word)
    max_cl = zeros(num_Outputs)
        # pre-allocate network arrays, using small floats for speed is probably just silly
        syn = Float32.(zeros(num_Hidden))
        mem = Float32.(zeros(num_Hidden));
        mthr = Float32.(zeros(num_Hidden))
        output = Float32.(zeros(num_Outputs))
        flt = Float32.(zeros(num_Outputs))

        for k = 1:Nsteps

            # hidden spikes: 1 where the membrane potential (from the previous step) exceeds 1
            mthr = mem .- 1.0
            mthr = spike(mthr)

            mem =  (mem_decay*mem + syn)        # update hidden layer membrane potentials
            # Reset neurons that spiked. The assignment inside @ignore's closure rebinds the
            # outer `mem`, so the reset applies in the forward pass while Zygote skips it.
            # NOTE(review): this makes `mem` a captured (boxed) variable — verify that
            # gradients still flow through the other `mem` updates as intended.
            @ignore mem = mem .* (1 .- mthr)    # don't propagate gradients through the spike reset
            syn = syn_decay*syn + a*x[k, :]     # update input currents to hidden layer

            # readout: leaky integration of the synaptically filtered hidden spikes
            output = mem_decay*output + flt
            flt = syn_decay*flt + b*mthr        # update filtered output

            if rec
                @ignore mem_rec[k , :] .= mem
                @ignore spk_rec[k , :] .= mthr
                @ignore out_rec[k , :] .= output
            end

            max_cl += output

  
        end
        if rec
            return max_cl./Nsteps, mem_rec, spk_rec, out_rec
        else
            return max_cl./Nsteps
        end
end

# Sanity check before training: run the untrained network on training pattern 2 and
# keep its recordings for the plots below.
(keep_output, mem_rec, spk_rec, out_rec) = snn(x[:, :, 2], a, b; rec = true);
# Plot recordings as a sanity check; only one pattern is run at a time.
"""
    plotMembrane(mem, Nsteps, spk = nothing)

Plot the first `Nsteps` rows of `mem` (time steps × neurons), one line per neuron. If a
spike matrix `spk` of the same size is given, entries where `spk > 0` are drawn as
spikes of height 5.

The y-limits pad the data range by a tenth of the maximum. When the maximum is not
positive (e.g. a silent readout that stays at 0), the padding is a tenth of
`max(|minimum|, 1)` instead, so the limits never collapse or clip the data.
"""
function plotMembrane(mem::AbstractMatrix{<:Real}, Nsteps::Integer, spk = nothing)
    mem = copy(mem)
    if spk !== nothing
        spike_height = 5
        mem[spk .> 0] .= spike_height
    end
    lower, upper = extrema(mem)
    # upper/10 padding as before when upper > 0; it would give empty or clipped limits otherwise
    offset = upper > 0 ? upper / 10 : max(abs(lower), one(lower)) / 10
    return plot(mem[1:Nsteps, :], lw = 2, ylim = (lower - offset, upper + offset), legend = false)
end
p2 = plotMembrane(mem_rec, Nsteps, spk_rec)   # hidden layer, spikes drawn on top
p3 = plotMembrane(out_rec, Nsteps)            # readout layer
p4 = plot(p2, p3;
    layout = (1, 2),
    xlabel = "Time (ms)",
    ylabel = "Membrane voltage (a.u.)",
    plot_title = "Hidden layer and readout before training")
display(p4)

In [None]:

numEpochs = 5   # passes over the whole training set in trainMe

### Training loop

In [None]:
loss_values = Float64[]  # summed (not averaged) training loss per epoch, appended by trainMe

"""
    trainMe(x, a, b, target; verbose = true)

Train `a` (input -> hidden) and `b` (hidden -> output) in place with Adam. Each epoch
runs through the `trainSize` patterns `x[:, :, i]` with one-hot labels `target[i, :]`,
for `numEpochs` epochs (both counts are read from globals). The per-sample loss is
`logitcrossentropy(snn output, label) + 0.07 * sqrt(sum(abs2, a))`.

After each epoch the summed loss is printed and pushed to the global `loss_values`.
Returns the mean loss per sample for each epoch. With `verbose = true` (default) the
network output of every sample is printed as well.
"""
function trainMe(x::Array{Int8, 3}, a::Array{Float32}, b::Array{Float32}, target::Array{Int8, 2};
                 verbose::Bool = true)
    # learning rate 0.0015; (beta1, beta2) control the optimizer's history dependence
    optimizer = ADAM(0.0015, (0.9, 0.999))
    θ = Flux.params(a, b)
    loss_hist = Float64[]

    function loss(x::Array{Int8, 2}, a::Array{Float32}, b::Array{Float32}, y::Vector{Int8})
        lambda = 0.07                    # scaling factor for regularization
        out = snn(x, a, b; rec = false)  # time-averaged readout, used as logits
        if verbose
            println("out:", out)
        end
        # penalty on the (unsquared) L2 norm of the input weights
        return Flux.logitcrossentropy(out, y) + lambda * sqrt(sum(abs2, a))
    end

    for k in 1:numEpochs
        # No sleep here: it only added a second per epoch to the timed training run.
        flush(stdout)
        accum_loss = 0.0
        for i in 1:trainSize
            # loss and pullback for sample i, then one optimizer step on a and b
            train_loss, back = pullback(θ) do
                loss(x[:, :, i], a, b, target[i, :])
            end
            accum_loss += train_loss
            update!(optimizer, θ, back(1.0))
        end
        push!(loss_hist, accum_loss / trainSize)
        println("Epoch ", k, " loss: ", accum_loss)
        push!(loss_values, accum_loss)
    end
    return loss_hist
end
@time loss_hist = trainMe(x, a, b, target);
# Store a snapshot: loss_values is a live global that later trainMe calls keep appending to.
push!(loss_values_more, copy(loss_values))

In [None]:
# The training cell already stores this run's curve. Only add it if it is not there yet,
# so a run is not counted twice in the per-seed plots and averages.
loss_values in loss_values_more || push!(loss_values_more, copy(loss_values))
print(size(loss_values_more))

### Loss over epochs (most recent run)

In [None]:


# x-axis taken from the data, so the plot still works if loss_values holds more than
# numEpochs entries (e.g. after re-running training without resetting it).
plot(1:length(loss_values), loss_values, seriestype=:line, xlabel="Epochs", ylabel="Loss", title="Loss over Epochs", grid=true)


In [None]:
# current loss curve, its length, and how many curves are stored
foreach(println, (loss_values, size(loss_values)))
println(size(loss_values_more))


### Loss over epochs per seed

In [None]:
using Plots

n = length(loss_values_more)
plot_layout = @layout [a b c; d e]   # room for five seeds

p = plot(layout=plot_layout)

# The loop variable is `series`, not `loss_values`: in an interactive top-level loop,
# assigning to an existing global name overwrites that global.
for (i, series) in enumerate(loss_values_more)
    plot!(p, 1:length(series), series, seriestype=:line, xlabel="Epochs", ylabel="Loss",
          title="Loss over Epochs $(i)", grid=true, subplot=i)
end
display(p)


### Average loss across seeds (current noise probability)

In [None]:

# Epoch-wise mean of all stored loss curves (each assumed to have >= numEpochs entries).
numSeries = length(loss_values_more)
average_loss = zeros(numEpochs)
for series in loss_values_more, epoch in 1:numEpochs
    average_loss[epoch] += series[epoch]
end
average_loss ./= numSeries

plot(1:numEpochs, average_loss, seriestype=:line, xlabel="Epochs", ylabel="Average Loss", title="Average Loss over Epochs", grid=true,  legend=false)

### Smoothed average loss

In [None]:
using Plots

# Recompute the epoch-wise mean loss over all stored runs (input to the smoothing below).
numSeries = length(loss_values_more)
average_loss = zeros(numEpochs)
for epoch in 1:numEpochs
    total = 0.0
    for series in loss_values_more
        total += series[epoch]
    end
    average_loss[epoch] = total / numSeries
end

"""
    moving_average(data, window_size)

Centered moving average of `data`. Entry `i` is the mean of
`data[max(1, i - h):min(end, i + h)]` with `h = div(window_size, 2)`. The window
therefore shrinks at both ends, and for an even `window_size` it spans
`window_size + 1` points.

Returns a `Vector{Float64}` of the same length as `data`. Throws an `ArgumentError` if
`window_size < 1`.
"""
function moving_average(data, window_size)
    window_size >= 1 ||
        throw(ArgumentError("window_size must be at least 1, got $window_size"))
    half = div(window_size, 2)
    n = length(data)
    smoothed_data = zeros(n)
    for i in 1:n
        window = view(data, max(1, i - half):min(n, i + half))
        # sum/length instead of `mean`: no dependence on Statistics being loaded yet
        smoothed_data[i] = sum(window) / length(window)
    end
    return smoothed_data
end

# window_size previously had no value: the line continued into the next assignment
# (`window_size = smoothed_loss = ...`), which used an undefined window_size.
# A 3-epoch window averages each epoch with its two neighbours.
window_size = 3
smoothed_loss = moving_average(average_loss, window_size)
plot(1:numEpochs, smoothed_loss, seriestype=:line, xlabel="Epochs", ylabel="Loss", title="Smoothed Loss over Epochs", legend=false)



In [None]:
append!(diff_smoothed_loss, [smoothed_loss])   # one smoothed curve per noise probability

### Smoothed loss for all noise probabilities

In [None]:
diff_smoothed_loss |> size |> print   # number of stored smoothed-loss curves

In [None]:
using Plots
n = length(diff_smoothed_loss)
plot_layout = @layout [a b c]   # one panel per noise probability
p = plot(layout = plot_layout)

# Only the first stored curve is drawn for now. Further runs go into subplot = 2, 3 in
# the same way, with their own titles.
smooth_loss1 = diff_smoothed_loss[1]
plot!(p, 1:numEpochs, smooth_loss1;
      seriestype = :line, xlabel = "Epochs", ylabel = "Loss", title = "p=0.60",
      grid = true, subplot = 1, legend = false)
display(p)

### Prepare the test phase

In [None]:
# Test on noise-free encodings of all training sequences (including the added pen_list words).
phoneme_sequences_test = phoneme_sequences

max_length_s_test = maximum(length.(phoneme_sequences_test))
println(max_length_s_test)
max_length_test = max_length_s_test
println(max_length_test)

testSize = length(phoneme_sequences_test)
target_test = one_hot_encoding(phoneme_sequences_test, word_list, mapping)
print(target_test)

enc_phonemes_test = encode_sequences(phoneme_sequences_test, spike_dict, letter_Nsteps,
                                     num_Inputs, max_length_test)

# Stack into x_test[time, input, sample]. This also resets the global Nsteps read by snn.
rows_test, cols_test = size(enc_phonemes_test[1])
Nsteps = rows_test
x_test = Array{Int8, 3}(undef, rows_test, cols_test, testSize)
for (i, raster) in enumerate(enc_phonemes_test)
    x_test[:, :, i] .= raster
end

println(target_test)
println(size(x_test))


### Re-run model after training and calculate accuracy

In [None]:
"""
    afterTraining(testSize, target_test, word_list; inputs = x_test, labels = phoneme_sequences_test)

Run the trained network (global weights `a`, `b`) on `inputs[:, :, j]` for
`j in 1:testSize`. Returns the fraction of patterns whose highest readout matches the
one-hot row `target_test[j, :]`, and prints per-pattern scores and predictions.

`inputs` and `labels` default to the globals `x_test` and `phoneme_sequences_test`, so
the three-argument call is unchanged. Pass them explicitly to evaluate another set.
"""
function afterTraining(testSize, target_test, word_list;
                       inputs = x_test, labels = phoneme_sequences_test)
    acc = 0
    for j in 1:testSize
        # only the time-averaged readout is used, so nothing is recorded (rec = false)
        keep_output = snn(inputs[:, :, j], a, b; rec = false)
        predicted = argmax(keep_output)
        expected = argmax(target_test[j, :])

        println(keep_output)
        println(predicted, "-target:", expected)
        println("choosen score:", maximum(keep_output))
        println("target_score:", keep_output[expected])

        predicted == expected && (acc += 1)
        println(acc, word_list[expected])
        println(word_list[predicted])
        println(acc, labels[j])
    end
    return acc / testSize
end


acc = afterTraining(testSize, target_test, word_list)

@printf("\nAccuracy: %.4f\n", acc)
push!(accuracies_normal_words, acc)


### Prepare noisy ('accented') versions of the words

In [None]:
# Every training word, with noise on the rows start_index:end_index (the phoneme chosen
# by index_letter_to_change in the training cell), encoded with the training max_length.
pen_list_add = ["bɛn", "pɛn", "tɛn", "mɛn", "hɛn", "dɛn", "lɛn", "jɛn"]
enc_phonemes_pen_add = encode_sequences(pen_list_add, spike_dict, letter_Nsteps, num_Inputs, max_length)
accent_encoded_pen = add_noice(start_index, end_index, enc_phonemes_pen_add, prob_change)

testSize_pen = length(pen_list_add)
target_test_pen = one_hot_encoding(pen_list_add, word_list, mapping)

# Stack into x_test_pen[time, input, sample]; also resets the global Nsteps read by snn.
rows, cols = size(accent_encoded_pen[1])
Nsteps = rows
x_test_pen = Array{Int8, 3}(undef, rows, cols, testSize_pen)
for (i, raster) in enumerate(accent_encoded_pen)
    x_test_pen[:, :, i] .= raster
end



### Test with noisy ('accented') words

In [None]:
"""
    afterTrainingPen(testSize_pen, target_test_pen, word_list)

Accuracy of the trained network on the noisy words. Runs `snn` with the global weights
on `x_test_pen[:, :, j]` for `j in 1:testSize_pen`, counts how often the highest
readout matches `target_test_pen[j, :]`, and prints per-word diagnostics.

It has its own name because defining another three-argument `afterTraining` here would
silently replace the clean-set evaluator. A later call meant for `x_test` would then
score `x_test_pen` instead.
"""
function afterTrainingPen(testSize_pen, target_test_pen, word_list)
    acc_pen = 0
    for j in 1:testSize_pen
        # only the time-averaged readout is used, so nothing is recorded (rec = false)
        keep_output = snn(x_test_pen[:, :, j], a, b; rec = false)
        predicted = argmax(keep_output)
        expected = argmax(target_test_pen[j, :])

        println(keep_output)
        println(predicted, "-target:", expected)
        println("choosen score:", maximum(keep_output))
        println("target_score:", keep_output[expected])

        predicted == expected && (acc_pen += 1)
        println(acc_pen, word_list[expected])
        println(word_list[predicted])
        println(acc_pen, pen_list_add[j])
    end
    return acc_pen / testSize_pen
end


acc_pen = afterTrainingPen(testSize_pen, target_test_pen, word_list)

@printf("\nAccuracy: %.4f\n", acc_pen)
push!(accuracies_pen, acc_pen)

### Spike patterns

In [None]:

# # Plot first input spike pattern, sanity check
# function plotInputPattern(x)
#     dat = copy(x)
#     dat[dat .< 1]  .= -5
#     p1 = scatter(dat[:,1], color = "black", markersize = 4, alpha = 0.5, ylims = (-1, num_Inputs+1), legend = false, 
#         xlabel = "Time(ms)", ylabel = "Input neuron")
#     for j = 2:num_Inputs
#         scatter!(j*dat[:,j], color = "black", markersize = 4, alpha = 0.5)
#     end
#     return p1
# end

# for i in 1:trainSize
#     display(plotInputPattern(x_test_pen[:, :, i]))
# end



In [None]:
# accuracies of all runs so far: clean words, then noisy words
foreach(print, (accuracies_normal_words, accuracies_pen))


### Calculate average accuracies (across seeds, current noise probability)

In [None]:
using Statistics

# Mean accuracy over all runs stored for the current noise probability.
avg, avg_pen = mean.((accuracies_normal_words, accuracies_pen))

In [None]:
foreach(print, (avg, avg_pen))   # clean-word average, then noisy-word average

In [None]:

# Record this probability's averages for the bar plots below.
for (store, value) in ((diff_accuracies_probabilities, avg),
                       (diff_accuracies_probabilities_pen, avg_pen))
    push!(store, value)
end


### Bar plot of accuracy after all noise probabilities

In [None]:
using Plots
# One bar per stored probability, paired by position.
# NOTE(review): assumes the runs were done in the order 0.2, 0.4, 0.6 — verify.
categories = ["p=0.20", "p=0.40", "p=0.60"]
bar(categories, diff_accuracies_probabilities;
    legend = false, title = "Labeling sequences it has trained with",
    xlabel = "Noise Probability", ylabel = "Accuracy")



In [None]:
using Plots
# One bar per stored probability, paired by position.
# NOTE(review): assumes the runs were done in the order 0.2, 0.4, 0.6 — verify.
categories = ["p=0.20", "p=0.40", "p=0.60"]
bar(categories, diff_accuracies_probabilities_pen;
    legend = false, title = "Labeling variations of sequences it has trained with",
    xlabel = "Noise Probability", ylabel = "Accuracy")



In [None]:
# function afterTraining(trainSize, target)
#     acc = 0
#     for j in 1:trainSize
#         keep_output, mem_rec, spk_rec, out_rec = snn(x[:,:,j],a,b; rec = true)

#         println("argmax output:",argmax(maximum(keep_output, dims=1)[1,:]))
#         println("target[j,;]:",target[j,:])
#         println("argmax(target[j,:]):",argmax(target[j,:]))
#         println("output",keep_output)
#         println("maximum(keep_output, dims=1)[1,:]:",maximum(keep_output, dims=1)[1,:])
#         (argmax(maximum(keep_output, dims=1)[1,:]) == argmax(target[j,:])) && (acc += 1)
#         println(acc, phoneme_sequences[j])
        

#         # println("keep_output: ", keep_output)
#         # println("mem_rec: ", mem_rec)
#         # println("spk_rec: ", spk_rec)
#         # println("out_rec: ", out_rec)
#     end
# return acc/trainSize

#end

### Plot results after training


In [None]:
# Re-run the trained network on training pattern 2 and plot hidden layer and readout.
keep_output, mem_rec, spk_rec, out_rec = snn(x[:, :, 2], a, b; rec = true)
p7 = plotMembrane(mem_rec, Nsteps, spk_rec)
p8 = plotMembrane(out_rec, Nsteps)
p9 = plot(p7, p8, layout = (1, 2), xlabel = "Time (ms)", ylabel = "Membrane voltage (a.u.)",
          plot_title = "Hidden layer and readout after training")
display(p9)

# Input -> hidden weight distribution before (keep_a) and after training (a). vec
# flattens in the same column-major order as splatting into vcat, without passing every
# weight as a separate function argument.
p5 = histogram(vec(keep_a), bins = :scott, alpha = 0.5, label = "Untrained")
histogram!(vec(a), bins = :scott, alpha = 0.5, label = "Trained", ylabel = "Frequency",
           xlabel = "Synaptic strength (Input to Hidden)", palette = :PuOr_4)
vline!([mean(vec(a))], lw = 3, linestyle = :dash, label = "Mean trained")
vline!([mean(vec(keep_a))], lw = 3, linestyle = :dash, label = "Mean untrained")
display(p5)

# loss_hist holds the mean per-sample training loss per epoch (cross-entropy plus the
# weight-norm penalty), not a mean squared error.
p6 = plot(loss_hist, ylabel = "Loss", xlabel = "Epoch", legend = false)
display(p6)