In [None]:
using Revise
using DrWatson
@quickactivate "SpikingNeuralNetworks"
using SpikingNeuralNetworks
SNN.@load_units
import SpikingNeuralNetworks: AdExParameter, IFParameter, IFConstParameter, AdExConstParameter
using Statistics, Random
using Plots

# LKD model parameters

In [None]:
## Membrane parameters
τm = 20ms    # Membrane time constant
Cm = 300pF   # Membrane capacitance (named Cm so the later `C = []` synapse lists don't overwrite it)
R = τm / Cm  # Membrane resistance implied by τm and Cm (the cells below spell this as 20ms/300pF)

## Synaptic time constants
τre = 1ms    # Rise time for excitatory synapses
τde = 6ms    # Decay time for excitatory synapses
τri = 0.5ms  # Rise time for inhibitory synapses
τdi = 2ms    # Decay time for inhibitory synapses

## Input and synapse parameters
N = 1000        # NOTE(review): not referenced by any visible cell
νe = 4.5Hz      # Rate of external input to E neurons
νi = 2.25Hz     # Rate of external input to I neurons
p_in = 0.2      # `p` for the Poisson-input projections (values 1.0 and 0.5 were also tried)
σ_in_E = 1.78pF # Input projection weight; used below for both the E and the I targets

σEE = 2.76pF # Initial E to E synaptic weight
σIE = 48.7pF # Initial I to E synaptic weight
σEI = 1.27pF # Synaptic weight from E to I
σII = 16.2pF # Synaptic weight from I to I

Random.seed!(23)  # Fix the global RNG so the random inputs/connectivity below are reproducible
duration = 700ms  # Simulation length used by the single-neuron and E-I cells
pltdur = 70e1     # NOTE(review): not referenced by any visible cell

## IF neuron with constant input current

In [None]:
# Parameters for a single IFConst neuron (constant input current). The membrane
# values match those used for LKD_IF_inh further down.
LKD_IF_const = IFConstParameter(
    τm = 20ms,
    R = 20ms / 300pF,
    El = -62mV,
    Vr = -60mV,
    Vt = -52mV,
    τabs = 1ms,
)

IFNeuron = SNN.IFConst(; N = 1, param = LKD_IF_const)

In [None]:
# Simulate the lone IF neuron (no synapses), recording :v and :fire.
C = Any[]
P = [IFNeuron]
SNN.monitor(P, [:v, :fire])
SNN.sim!(P, C; duration = duration)

In [None]:
# Membrane-potential trace of the IF neuron under constant input.
p1 = plot(
    SNN.vecplot(IFNeuron, :v);
    title = "IF neuron with constant input",
    xlabel = "Time (ms)",
    ylabel = "Membrane Potential (mV)",
)

## AdEx neuron with constant input current

In [None]:
# Parameters for a single AdExConst neuron (constant input current).
LKD_AdEx_const = AdExConstParameter(
    # membrane
    τm = 20ms,
    R = 20ms / 300pF,
    El = -70mV,
    Vr = -60mV,
    Vt = -52mV,
    ΔT = 2mV,
    τabs = 1ms,
    # w-related parameters (w is plotted below as the adaptation current)
    a = 4nS,
    b = 0.805pA,
    τw = 150ms,
    # presumably threshold adaptation - confirm against AdExConstParameter
    At = 10mV,
    τT = 30ms,
)

AdExNeuron = SNN.AdExConst(; N = 1, param = LKD_AdEx_const)

In [None]:
# Simulate the lone AdEx neuron (no synapses), recording :v, :fire and :w.
C = Any[]
P = [AdExNeuron]
SNN.monitor(P, [:v, :fire, :w])
SNN.sim!(P, C; duration = duration)

In [None]:
# Left: membrane potential; right: adaptation variable w of the AdEx neuron.
p1 = plot(
    SNN.vecplot(AdExNeuron, :v);
    xlabel = "Time (ms)",
    ylabel = "Membrane Potential (mV)",
)
p2 = plot(
    SNN.vecplot(AdExNeuron, :w);
    xlabel = "Time (ms)",
    ylabel = "Adaptation current (pA)",
)
plot(p1, p2; size = (1100, 450), title = "AdEx neuron with constant input")

## AdEx neuron with Poisson input

In [None]:
# AdEx parameters with synaptic kinetics (the rise/decay times defined above) and
# E_e/E_i. This is reused as the excitatory neuron model in the cells below.
LKD_AdEx_exc = AdExParameter(
    τm = 20ms, R = 20ms / 300pF, El = -70mV, Vr = -60mV, Vt = -52mV, ΔT = 2mV, τabs = 1ms,
    a = 4nS, b = 0.805pA, τw = 150ms,
    At = 10mV, τT = 30ms,
    τre = τre, τde = τde, τri = τri, τdi = τdi,
    E_e = 0mV, E_i = -75mV,
)

E = SNN.AdEx(; N = 1, param = LKD_AdEx_exc)

# 500 Poisson sources at rate νe feeding E's :ge input (p = p_in).
# Keep this construction order: the projection may draw from the global RNG.
Input_E = SNN.Poisson(; N = 500, param = SNN.PoissonParameter(; rate = νe))
ProjE = SNN.SpikingSynapse(Input_E, E, :ge; σ = σ_in_E, p = p_in)

In [None]:
# Run the AdEx neuron together with its Poisson input population.
C = [ProjE]
P = [E, Input_E]
SNN.monitor([E], [:v, :fire, :w])
SNN.sim!(P, C; duration = duration)

In [None]:
# Left: membrane potential; right: adaptation variable w of the Poisson-driven AdEx neuron.
p1 = plot(
    SNN.vecplot(E, :v);
    xlabel = "Time (ms)",
    ylabel = "Membrane Potential (mV)",
)
p2 = plot(
    SNN.vecplot(E, :w);
    xlabel = "Time (ms)",
    ylabel = "Adaptation current (pA)",
)
plot(p1, p2; size = (1100, 450), title = "AdEx neurons with Poisson inputs")

## IF neuron with Poisson input

In [None]:
# IF parameters sharing the synaptic parameters of LKD_AdEx_exc. This is reused
# as the inhibitory neuron model in the cells below.
LKD_IF_inh = IFParameter(
    τm = 20ms, R = 20ms / 300pF, El = -62mV, Vr = -60mV, Vt = -52mV, ΔT = 2mV,
    τre = τre, τde = τde, τri = τri, τdi = τdi,
    E_e = 0mV, E_i = -75mV,
    τabs = 1ms,
)

I = SNN.IF(; N = 1, param = LKD_IF_inh)

# 500 Poisson sources at rate νi feeding I's :ge input (same weight σ_in_E as the E input).
# Keep this construction order: the projection may draw from the global RNG.
Input_I = SNN.Poisson(; N = 500, param = SNN.PoissonParameter(; rate = νi))
ProjI = SNN.SpikingSynapse(Input_I, I, :ge; σ = σ_in_E, p = p_in)

In [None]:
# Run the IF neuron together with its Poisson input population.
C = [ProjI]
P = [I, Input_I]
SNN.monitor([I], [:v, :fire])
SNN.sim!(P, C; duration = duration)

In [None]:
# Membrane-potential trace of the Poisson-driven IF neuron.
p1 = plot(
    SNN.vecplot(I, :v);
    title = "IF neuron and Poisson inputs",
    xlabel = "Time (ms)",
    ylabel = "Membrane Potential (mV)",
)

## AdEx excitatory neuron coupled with an IF inhibitory neuron

In [None]:
# One AdEx (E) and one IF (I) neuron, each with its own Poisson drive, coupled
# E->I, I->E and through self-connections. IE carries iSTDP and EE carries vSTDP parameters.
# Statement order is kept as-is: synapse construction may draw from the global RNG.
E = SNN.AdEx(; N = 1, param = LKD_AdEx_exc)
I = SNN.IF(; N = 1, param = LKD_IF_inh)

# external Poisson drive
Input_E = SNN.Poisson(; N = 600, param = SNN.PoissonParameter(; rate = νe))
ProjE = SNN.SpikingSynapse(Input_E, E, :ge; σ = σ_in_E, p = p_in)
Input_I = SNN.Poisson(; N = 150, param = SNN.PoissonParameter(; rate = νi))
ProjI = SNN.SpikingSynapse(Input_I, I, :ge; σ = σ_in_E, p = p_in)

# recurrent coupling (p = 1.0)
EI = SNN.SpikingSynapse(E, I, :ge; σ = σEI, p = 1.0)
IE = SNN.SpikingSynapse(I, E, :gi; σ = σIE, p = 1.0, param = SNN.iSTDPParameter())
EE = SNN.SpikingSynapse(E, E, :ge; σ = σEE, p = 1.0, param = SNN.vSTDPParameter())
II = SNN.SpikingSynapse(I, I, :gi; σ = σII, p = 1.0)

In [None]:
# Run the coupled E-I pair with train! (IE and EE carry STDP parameters).
C = [EE, II, EI, IE, ProjE, ProjI]
P = [E, I, Input_E, Input_I]
SNN.monitor([E, I], [:v, :fire])
SNN.train!(P, C; duration = duration)

In [None]:
# Side-by-side voltage traces of the coupled excitatory (left) and inhibitory (right) neurons.
ylab = "Membrane Potential (mV)"
p1 = plot(
    SNN.vecplot([E], :v);
    title = "AdEx neuron coupled with IF neuron with Poisson inputs",
    xlabel = "Time (ms)",
    ylabel = ylab,
)
p2 = plot(
    SNN.vecplot([I], :v);
    title = "IF neuron coupled with AdEx neuron with Poisson inputs",
    xlabel = "Time (ms)",
    ylabel = ylab,
)
plot(p1, p2; size = (1500, 450))

## Exploring plastic (STDP) weight dynamics

In [None]:
# Sweep the Poisson input rate. For each rate, build a fresh pair of reciprocally
# connected AdEx neurons with vSTDP synapses, train for 5 s, and record the final
# Pre->Post and Post->Pre weights.
rates = 1Hz:0.5Hz:50Hz
PrePost_weights = Vector{Float32}()
PostPre_weights = Vector{Float32}()

for fr in rates
    Pre = SNN.AdEx(; N = 1, param = LKD_AdEx_exc)
    # Post = SNN.IF(; N = 1, param = LKD_IF_inh)
    Post = SNN.AdEx(; N = 1, param = LKD_AdEx_exc)

    # Each neuron gets its own Poisson drive at rate `fr`.
    Input_Pre = SNN.Poisson(; N = 600, param = SNN.PoissonParameter(; rate = fr))
    ProjPre = SNN.SpikingSynapse(Input_Pre, Pre, :ge; σ = σ_in_E, p = p_in)
    Input_Post = SNN.Poisson(; N = 150, param = SNN.PoissonParameter(; rate = fr))
    # Wire Post's own input; this previously used the stale globals Input_I -> I,
    # which left Post without any external drive.
    ProjPost = SNN.SpikingSynapse(Input_Post, Post, :ge; σ = σ_in_E, p = p_in)

    PrePost = SNN.SpikingSynapse(Pre, Post, :ge; σ = σEI, p = 1.0, param = SNN.vSTDPParameter())
    # NOTE(review): the fixed-rate cell below uses :gi for this connection - confirm which is intended.
    PostPre = SNN.SpikingSynapse(Post, Pre, :ge; σ = σIE, p = 1.0, param = SNN.vSTDPParameter())

    P = [Pre, Post, Input_Pre, Input_Post]
    C = [PostPre, PrePost, ProjPre, ProjPost]

    SNN.monitor([Pre, Post], [:v, :fire])
    SNN.train!(P, C; duration = 5second)

    push!(PrePost_weights, PrePost.W[1])
    push!(PostPre_weights, PostPre.W[1])
end

# Derive the x-axis from the same range that drove the loop, expressed in Hz.
rates_Hz = rates ./ Hz
p1 = plot(rates_Hz, PrePost_weights, title = "Weight pre-post", xlabel = "Firing rate (Hz)", ylabel = "Weight (pF)")
p2 = plot(rates_Hz, PostPre_weights, title = "Weight post-pre", xlabel = "Firing rate (Hz)", ylabel = "Weight (pF)")
plot(p1, p2, size = (1000, 450))

In [None]:
# Train one reciprocally connected AdEx pair at a fixed input rate and record both
# plastic weights after each 1 s training segment.
fr = 5Hz
n_segments = 5

Pre = SNN.AdEx(; N = 1, param = LKD_AdEx_exc)
# Post = SNN.IF(; N = 1, param = LKD_IF_inh)
Post = SNN.AdEx(; N = 1, param = LKD_AdEx_exc)

Input_Pre = SNN.Poisson(; N = 600, param = SNN.PoissonParameter(; rate = fr))
ProjPre = SNN.SpikingSynapse(Input_Pre, Pre, :ge; σ = σ_in_E, p = p_in)
Input_Post = SNN.Poisson(; N = 150, param = SNN.PoissonParameter(; rate = fr))
# Wire Post's own input (this previously used the stale globals Input_I -> I).
ProjPost = SNN.SpikingSynapse(Input_Post, Post, :ge; σ = σ_in_E, p = p_in)

PrePost = SNN.SpikingSynapse(Pre, Post, :ge; σ = σEI, p = 1.0, param = SNN.vSTDPParameter())
PostPre = SNN.SpikingSynapse(Post, Pre, :gi; σ = σIE, p = 1.0, param = SNN.vSTDPParameter())

P = [Pre, Post, Input_Pre, Input_Post]
C = [PostPre, PrePost, ProjPre, ProjPost]

SNN.monitor([Pre, Post], [:v, :fire])

# Row 1: PostPre weight, row 2: PrePost weight; one column per training segment.
W = zeros(2, n_segments)
for i in 1:n_segments
    # NOTE(review): switched from sim! to train!, which the other training cells use;
    # sim! presumably skips plasticity, so W would never change - confirm.
    SNN.train!(P, C; duration = 1second)
    W[1, i] = PostPre.W[1]
    W[2, i] = PrePost.W[1]
end

# p1 = plot(W[1, :], title = "Weight post-pre", xlabel = "Segment (s)", ylabel = "Weight (pF)")
# p2 = plot(W[2, :], title = "Weight pre-post", xlabel = "Segment (s)", ylabel = "Weight (pF)")
# plot(p1, p2, size = (1000, 450))

SNN.vecplot([Pre], :v)

In [None]:
# AdEx (Pre) -> IF (Post) pair with vSTDP synapses in both directions. Only Pre
# receives Poisson drive. Weights are printed before and after 5 s of training.
Pre = SNN.AdEx(; N = 1, param = LKD_AdEx_exc)
Post = SNN.IF(; N = 1, param = LKD_IF_inh)
# Post = SNN.AdEx(; N = 1, param = LKD_AdEx_exc)

Input_Pre = SNN.Poisson(; N = 400, param = SNN.PoissonParameter(; rate = 10Hz))
ProjPre = SNN.SpikingSynapse(Input_Pre, Pre, :ge; σ = σ_in_E, p = 0.2) # connection from input to E
# Optional drive to Post. If re-enabled, also add Input_Post to P and ProjPost to C.
# Input_Post = SNN.Poisson(; N = 100, param = SNN.PoissonParameter(; rate = 30Hz))
# ProjPost = SNN.SpikingSynapse(Input_Post, Post, :ge; σ = σ_in_E, p = 0.05)

PrePost = SNN.SpikingSynapse(Pre, Post, :ge; σ = σEI, p = 1.0, param = SNN.vSTDPParameter())
# Connect Post back onto Pre. This previously used the globals I -> E from an earlier cell,
# so the trained and printed PostPre weight did not belong to this pair.
PostPre = SNN.SpikingSynapse(Post, Pre, :gi; σ = σIE, p = 1.0, param = SNN.vSTDPParameter()) # SNN.iSTDPParameter()
# PrePre = SNN.SpikingSynapse(Pre, Pre, :ge; σ = σEE, p = 1.0) # , param=SNN.vSTDPParameter()
# PostPost = SNN.SpikingSynapse(Post, Post, :gi; σ = σII, p = 1.0)

# Only objects built in this cell: Input_Post/ProjPost are commented out above.
P = [Pre, Post, Input_Pre]
C = [PrePost, PostPre, ProjPre] # PrePre, PostPost,
println("Initial weights - PrePost: ", PrePost.W, ", PostPre: ", PostPre.W)

SNN.monitor([Pre, Post], [:v, :fire])
SNN.train!(P, C; duration = 5second)

println("Final weights - PrePost: ", PrePost.W, ", PostPre: ", PostPre.W)

p1 = plot(SNN.vecplot([Pre], :v))
p2 = plot(SNN.vecplot([Post], :v))
plot(p1, p2, size = (1000, 450))