In [None]:
from pymonntorch import NeuronGroup, SynapseGroup, NeuronDimension, Recorder, Behavior, EventRecorder
import random
from conex import (
    Neocortex,
    prioritize_behaviors,
)
from conex.behaviors.neurons import (
    SimpleDendriteStructure,
    SimpleDendriteComputation,
    LIF,
    SpikeTrace,
    NeuronAxon,
    Fire,
    KWTA,
    ActivityBaseHomeostasis
)
from conex.behaviors.synapses import (
    SynapseInit,
    WeightInitializer,
    SimpleDendriticInput,
    LateralDendriticInput,
    SimpleSTDP,
    WeightClip
)


import torch
import matplotlib.pyplot as plt

from scipy.stats import poisson

In [None]:
##################################################
# parameters
##################################################
DEVICE = "cpu"
DTYPE = torch.float32
DT = 0.1

# Reproducibility: the pattern generation draws from `random`, and torch is
# seeded as well so that any torch-based random initialisation (e.g. the
# "normal(...)" specs used for weights and init_v) repeats across
# Restart & Run All. Change SEED to sample a different run.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Value handed to SpikeTrace(tau_s=...) on both populations.
TAU_S = 5

# input layer
# NOTE(review): the last key is "v_init" here but "init_v" in the output-layer
# dicts; the input-population cell does not read it.
LIF_INPUT1 = {
    "R": 0,
    "threshold": 0.5,
    "tau": 0,
    "v_reset": 0,
    "v_rest": 0,
    "v_init": 0,
}
LIF_INPUT = LIF_INPUT1

# output layer
LIF_OUTPUT1 = {
    "R": 5,
    "threshold": -66,
    "tau": 10,
    "v_reset": -68,
    "v_rest": -67,
    "init_v": "normal(-68,3)",
}
# LIF_OUTPUT2 = {"R":5 , "threshold":"normal(-37,15)" , "tau":50 , "v_reset":-65 , "v_rest":-67, "init_v":"normal(-50,10)"}
LIF_OUTPUT2 = {
    "R": 8,
    "threshold": -37,
    "tau": 10,
    "v_reset": -65,
    "v_rest": -60,
    "init_v": "normal(-50,20)",
}
LIF_OUTPUT = LIF_OUTPUT2

# KWTA: number of winners (only used by the commented-out KWTA behaviors)
K = 1

# LAT. INH.: coefficient of the lateral inhibitory input
LAT_COEF = 50

# ACT BASED HOMEOSTASIS (only used by the commented-out ActivityBaseHomeostasis)
WIN_S = 240
ACT_R = 10
UPD_R = 0.01
DEC_R = 1.0


# STDP
A_PLUS = 0.07
A_MINUS = 0.06



# Patterns
t = 80  # number of time steps (rows) in each pattern
n = 12  # every pattern row has 2*n entries, one per input neuron
c = 10  # gain applied to the Poisson pmf to get a per-step spike probability

# Number of input neurons driven by each pattern: P1 drives the first N_ACTIVE
# neurons, P2 the last N_ACTIVE, and the remaining N_SILENT entries of a row
# stay 0. Other splits that were tried: n (no overlap), n*4//3 and n*5//4.
N_ACTIVE = n * 3 // 2
N_SILENT = 2 * n - N_ACTIVE

# Spike probability at step i follows a Poisson pmf centred on t/2, scaled by
# c. It only depends on the step, so it is computed once per row rather than
# once per neuron.
SPIKE_PROB = [poisson.pmf(k=i, mu=t / 2) * c for i in range(t)]


def _poisson_rows(leading_zeros, trailing_zeros):
    """Build t rows of 0/1 spikes with N_ACTIVE random entries per row.

    Each active entry of row i is 1 with probability SPIKE_PROB[i]; the active
    block is padded with `leading_zeros` zeros before it and `trailing_zeros`
    zeros after it.
    """
    rows = []
    for prob in SPIKE_PROB:
        active = [1 if random.uniform(0, 1) <= prob else 0 for _ in range(N_ACTIVE)]
        rows.append([0] * leading_zeros + active + [0] * trailing_zeros)
    return rows


P1 = _poisson_rows(0, N_SILENT)
P2 = _poisson_rows(N_SILENT, 0)
P3 = [[0] * (2 * n) for _ in range(t)]  # silent pattern


In [None]:
class Current(Behavior):
    """Writes a fixed current into ``ng.I`` and ``ng.inpI`` on every step.

    Parameters (read through ``self.parameter``):
        MODE: current mode, default ``"ConstantCurrent"``. Any other value
            makes both ``initialize`` and ``forward`` do nothing.
        value: argument handed to ``ng.vector`` to build the current
            (default ``None``).
    """

    def initialize(self, ng):
        """Read the parameters and set the initial current vectors."""
        self.currentMode = self.parameter("MODE", default="ConstantCurrent")
        if self.currentMode != "ConstantCurrent":
            return
        self.value = self.parameter("value", None)
        ng.I = ng.vector(self.value)
        ng.inpI = ng.vector(mode=self.value)

    def forward(self, ng):
        """Rebuild both current vectors from ``value`` each iteration."""
        if self.currentMode != "ConstantCurrent":
            return
        ng.I = ng.vector(self.value)
        ng.inpI = ng.vector(self.value)
		
class InputPatterns(Behavior):
    """Drives ``ng.I`` with one of three patterns, switching every 100 iterations.

    Parameters (read through ``self.parameter``): ``pat1``, ``pat2``, ``pat3``.
    The presentation order is pat1, pat3, pat2, pat3, repeated.
    """

    def initialize(self, ng: NeuronGroup):
        ng.I = ng.vector(mode="zeros")
        self.pat1 = self.parameter("pat1")
        self.pat2 = self.parameter("pat2")
        self.pat3 = self.parameter("pat3")

    def forward(self, ng: NeuronGroup):
        ng.I = self.getPat(ng.iteration, 100)

    def getPat(self, t, c):
        """Return the pattern scheduled for iteration ``t`` as a float32 tensor.

        ``c`` is the number of iterations each pattern is held for.
        """
        schedule = (self.pat1, self.pat3, self.pat2, self.pat3)
        slot = int((t // c) % 4)
        return torch.tensor(schedule[slot], dtype=torch.float32)


In [None]:
class LateralInhibition(Behavior):
    """Subtracts a current from every non-spiking neuron of the group.

    The amount is ``coef`` times the number of neurons that spiked; neurons
    that spiked themselves receive no inhibition.

    Parameters (read through ``self.parameter``):
        coef: inhibition strength per spike (default ``None``).
    """

    def initialize(self, ng: NeuronGroup):
        self.coef = self.parameter("coef", default=None)

    def forward(self, ng: NeuronGroup):
        silent_mask = torch.logical_not(ng.spikes).float()
        spike_count = torch.sum(ng.spikes)
        inhibition = silent_mask * spike_count
        # Scale and flip the sign in place, then add the (negative) current.
        inhibition.mul_(self.coef).neg_()
        ng.I += inhibition

In [None]:
class costumKWTA(Behavior):
    """
    K-winners-take-all on the membrane potential of spiking neurons.

    When more than ``k`` neurons are at or above threshold (counted along the
    selected dimension), the ``k`` entries with the highest ``v`` are kept and
    every other neuron that is at or above threshold has its potential set to
    ``(v_reset - 37) / 2``. Note that this is not ``v_reset`` itself.

    Note: Population should be built by NeuronDimension (otherwise the shape
    falls back to ``(size, 1, 1)``), and a firing behavior should be added too.

    Args:
        k (int): number of winners.
        dimension (int, optional): K-WTA on specific dimension. defaults to None.
    """

    def __init__(self, k, *args, dimension=None, **kwargs):
        super().__init__(*args, k=k, dimension=dimension, **kwargs)

    def initialize(self, neurons):
        self.k = self.parameter("k", None, required=True)
        self.dimension = self.parameter("dimension", None)
        # Shape used to view the flat vectors when a dimension is given.
        self.shape = (neurons.size, 1, 1)
        if hasattr(neurons, "depth"):
            self.shape = (neurons.depth, neurons.height, neurons.width)

    def forward(self, neurons):
        """Suppress supra-threshold neurons that are not among the k winners."""
        # Fresh boolean tensor; it is modified in place further down.
        will_spike = neurons.v >= neurons.threshold
        v_values = neurons.v

        dim = 0
        if self.dimension is not None:
            v_values = v_values.view(self.shape)
            will_spike = will_spike.view(self.shape)
            dim = self.dimension

        # Nothing to do unless some slice along `dim` has more than k candidates.
        if (will_spike.sum(axis=dim) <= self.k).all():
            return

        _, k_winners_indices = torch.topk(
            v_values, self.k, dim=dim, sorted=False
        )

        # `ignored` aliases `will_spike`; clearing the winners' entries leaves
        # only the supra-threshold neurons that lost the competition.
        ignored = will_spike
        ignored.scatter_(dim, k_winners_indices, False)

        # NOTE(review): -37 equals the "threshold" of LIF_OUTPUT2 in this
        # notebook, which would make this the midpoint between v_reset and
        # threshold. Confirm that this is the intent.
        neurons.v[ignored.view((-1,))] = (neurons.v_reset -37)/2


In [None]:
class weight_decay(Behavior):
    """Subtracts a constant amount from every synaptic weight on each step.

    Parameters (read through ``self.parameter``):
        rate: amount subtracted from ``sg.weights`` per iteration.
            Defaults to 0.0001, so ``weight_decay()`` keeps the decay that
            used to be hard-coded.

    The subtraction is in place and unbounded; this class does no clipping.
    """

    def initialize(self, sg):
        self.rate = self.parameter("rate", 0.0001)

    def forward(self, sg):
        sg.weights -= self.rate

In [None]:
class ForcedNeuron(Behavior):
    """Neuron model whose membrane potential is overwritten from stored patterns.

    On every step ``neurons.v`` is replaced by one row of ``pat1`` or ``pat2``;
    the two patterns alternate every ``timeInterval`` iterations, starting with
    ``pat1``. ``pat3`` is read as a parameter but is not used by ``getPat``.

    Constructor args: ``R``, ``threshold``, ``tau``, ``v_reset``, ``v_rest``,
    optional ``init_v`` / ``init_s``. ``pat1``, ``pat2``, ``pat3`` and
    ``timeInterval`` are passed as extra keyword arguments.
    """

    def __init__(
        self,
        R,
        threshold,
        tau,
        v_reset,
        v_rest,
        *args,
        init_v=None,
        init_s=None,
        **kwargs
    ):
        super().__init__(
            *args,
            R=R,
            tau=tau,
            threshold=threshold,
            v_reset=v_reset,
            v_rest=v_rest,
            init_v=init_v,
            init_s=init_s,
            **kwargs
        )

    def initialize(self, neurons):
        self.add_tag(self.__class__.__name__)

        # Neuron constants and initial potential.
        neurons.R = self.parameter("R", default=0, required=False)
        neurons.tau = self.parameter("tau", default=0, required=False)
        neurons.threshold = self.parameter("threshold", default=0.5, required=False)
        neurons.v_reset = self.parameter("v_reset", default=0, required=False)
        neurons.v_rest = self.parameter("v_rest", default=0, required=False)
        neurons.v = self.parameter("init_v", neurons.vector())

        # Patterns that will be forced onto the potential.
        self.pat1 = self.parameter("pat1")
        self.pat2 = self.parameter("pat2")
        self.pat3 = self.parameter("pat3")
        neurons.spikes = self.parameter("init_s", neurons.v >= neurons.threshold)

        # Exposed on the group, presumably so a separate firing behavior can
        # call Fire() below (confirm against the Fire behavior in use).
        neurons.spiking_neuron = self
        self.timeInterval = self.parameter("timeInterval")

    def Fire(self, neurons):
        """Mark neurons at or above threshold as spiking and reset their potential."""
        fired = neurons.v >= neurons.threshold
        neurons.spikes = fired
        neurons.v[fired] = neurons.v_reset

    def forward(self, neurons):
        neurons.v = self.getPat(neurons.iteration, self.timeInterval)

    def getPat(self, t, c):
        """Return the pattern row for iteration ``t`` as a float32 tensor.

        ``c`` is how many iterations each pattern is held for. Even-numbered
        blocks use ``pat1`` and odd-numbered blocks use ``pat2``; the row
        inside the pattern is ``t % timeInterval``.
        """
        pattern = self.pat1 if (t // c) % 2 == 0 else self.pat2
        row = pattern[t % self.timeInterval]
        return torch.tensor(row, dtype=torch.float32)

In [None]:
# IN: network plus the input population whose potentials are forced from P1/P2.
net = Neocortex(dt=DT, device=DEVICE, dtype=DTYPE)

# Only these LIF_INPUT entries are forwarded; the dict's "v_init" key is not.
_INPUT_KEYS = ("R", "threshold", "tau", "v_reset", "v_rest")

behavior_in = {
    10: ForcedNeuron(
        **{key: LIF_INPUT[key] for key in _INPUT_KEYS},
        pat1=P1,
        pat2=P2,
        pat3=P3,
        timeInterval=t,
    ),
    20: Fire(),
    30: SpikeTrace(tau_s=TAU_S),
    40: NeuronAxon(),
    100: EventRecorder(tag="pop1_evrec", variables=["spikes"]),
}

# One input neuron per column of the pattern rows.
pop_inp = NeuronGroup(net=net, size=len(P1[0]), behavior=behavior_in)

In [None]:
# OUT: two LIF neurons; voltage, current and spikes are recorded.
behavior_out = prioritize_behaviors(
    [
        SimpleDendriteStructure(),
        SimpleDendriteComputation(),
        # LIF_OUTPUT holds exactly the keyword arguments passed here:
        # R, threshold, tau, v_reset, v_rest and init_v.
        LIF(**LIF_OUTPUT),
        SpikeTrace(tau_s=TAU_S),
        # KWTA(k=K),
        Fire(),
        NeuronAxon(),
        # ActivityBaseHomeostasis(window_size=WIN_S, activity_rate=ACT_R, updating_rate=UPD_R, decay_rate=DEC_R),
    ]
)

# Recorders are added with explicit keys.
behavior_out[1000] = Recorder(tag="pop2_rec", variables=["v", "I"])
behavior_out[1001] = EventRecorder(tag="pop2_evrec", variables=["spikes"])
# Optional competition mechanisms, currently disabled:
# behavior_out[250] = LateralInhibition(coef=LAT_COEF)
# behavior_out[300] = costumKWTA(k=K)


pop_out = NeuronGroup(net=net, size=2, behavior=behavior_out)


In [None]:
# SYN: plastic feed-forward synapses from the input to the output population.
behavior_syn = prioritize_behaviors(
    [
        SynapseInit(),
        WeightInitializer(mode="normal(0.3,0.2)"),
        SimpleDendriticInput(current_coef=90),
        SimpleSTDP(a_plus=A_PLUS, a_minus=A_MINUS),
        WeightClip(w_max=1, w_min=0),
    ]
)
behavior_syn[1000] = Recorder(tag="syn_rec", variables=["weights"])
# behavior_syn[390] = weight_decay()


syn_inp_out = SynapseGroup(
    net=net,
    src=pop_inp,
    dst=pop_out,
    tag="Proximal",
    behavior=behavior_syn,
)

# Lateral inhibitory synapses inside the output population, using a fixed
# [1, 0, 1] kernel of shape (1, 1, 1, 1, 3).
syn_lateral = SynapseGroup(
    net=net,
    src=pop_out,
    dst=pop_out,
    tag="Proximal, inh",
    behavior=prioritize_behaviors(
        [
            SynapseInit(),
            WeightInitializer(weights=torch.Tensor([1, 0, 1]).view(1, 1, 1, 1, 3)),
            LateralDendriticInput(inhibitory=True, current_coef=LAT_COEF),
        ]
    ),
)


# Build the network and run it for `it` iterations.
it = 2000
net.initialize()
net.simulate_iterations(it)


In [None]:
# Raster plot of the whole run: output spikes on rows 0-1, input spikes
# shifted up by 2 so they sit above them.
raster_fig = plt.figure()
raster_fig.set_size_inches(15, 5)

post_events = net["spikes", 1]
pre_events = net["spikes", 0]
plt.scatter(post_events[:, 0], post_events[:, 1], s=0.25, c="teal", alpha=0.8, label="posts")
plt.scatter(pre_events[:, 0], pre_events[:, 1] + 2, s=0.25, c="palevioletred", alpha=1, label="patterns")
plt.xlabel("t")
plt.ylabel("neuron")
plt.legend()
plt.title("Input and Output Spike Trains")

# Faint guide line for each of the two output neurons.
for row in range(2):
    plt.plot([0, it], [row, row], color="black", alpha=0.1)

# Table of the run's main parameters, drawn below the axes.
FONT = "arial"
parameters = ["PreSize", "TimeWindow", "LateralCoef", "KWTA"]
values = [pop_inp.size, t, LAT_COEF, 0]
table_data = [parameters, values]
table = plt.table(
    cellText=table_data,
    cellLoc="center",
    loc="bottom",
    bbox=[0, -0.3, 1, 0.2],
)
table.scale(0.8, 0.8)
table.auto_set_font_size(False)
table.set_fontsize(9)

for col in range(len(parameters)):
    header_cell = table[(0, col)]
    header_cell.set_facecolor("whitesmoke")
    header_cell.set_text_props(color='black', fontfamily=FONT)

for col in range(len(values)):
    value_cell = table[(1, col)]
    # Alternate the background of the value cells.
    value_cell.set_facecolor('#e6e6e6' if col % 2 else "snow")
    value_cell.set_text_props(fontfamily=FONT)

plt.show()



# Zoomed raster: only events recorded after iteration `showafter`.
showafter = 1600

plt.figure().set_size_inches(10, 5)

# Input spikes, shifted above the output rows by the output population size.
spike = net['pop1_evrec', 0]['spikes', 0][:, 1]
time = net['pop1_evrec', 0]['spikes', 0][:, 0]
recent = time > showafter  # computed once and reused for both coordinates
plt.scatter(time[recent], spike[recent] + pop_out.size, s=0.25, c='palevioletred', label="patterns")

# Output spikes.
spike = net['pop2_evrec', 0]['spikes', 0][:, 1]
time = net['pop2_evrec', 0]['spikes', 0][:, 0]
recent = time > showafter

# Faint guide line for each output neuron.
for i in range(pop_out.size):
    plt.plot([showafter, it], [i, i], color="black", alpha=0.1)

plt.scatter(time[recent], spike[recent], s=5, c='teal', alpha=0.8, label="posts")
plt.xlabel("t")
plt.ylabel("neuron")
plt.legend()
plt.title(f"Input and Output Spike Trains (t > {showafter})")
plt.show()

