# Sparse Spiking Ensemble

- Ensemble = Sensory neurons + latent neurons
- Sensory neurons get spikes from sensory inputs
- Sensory inputs are sparse coded (~10% 1s, rest 0s; the fraction is set by `SPARSITY`)
- Each neuron also takes input from the whole ensemble (later to be restricted locally)


In [15]:
import torch

# Length of one sensory input vector for modality 1 (also the number of
# afferent weights per sensory neuron).
SENSORY_MODALITY_1_SIZE = 100

# Neuron counts. In the code shown, only the sensory neurons are driven by
# the modality-1 input; the latent neurons are not.
SENSORY_MODALITY_1_NEURONS = 2
LATENT_NEURONS = 3

# Probability that any single element of a sparse embedding is 1.
SPARSITY = 0.1
# Offset subtracted from uniformly drawn initial weights, so that a fraction
# of them start out negative.
NEGATIVE_WEIGHT = 0.2

# Total number of neurons in the ensemble (sensory + latent).
ENSEMBLE_NEURONS = SENSORY_MODALITY_1_NEURONS + LATENT_NEURONS

def make_sparse_embedding(size, sparsity=None):
  """Return a random binary float tensor whose elements are 1 with probability `sparsity`.

  Args:
    size: Number of elements (an int), or any shape accepted by `torch.rand`.
    sparsity: Probability in [0, 1] that an element is 1. When omitted, the
      module-level `SPARSITY` constant is used.

  Returns:
    A float tensor containing only 0.0 and 1.0 values.
  """
  if sparsity is None:
    # Looked up at call time so the default always follows the module constant.
    sparsity = SPARSITY
  # Uniform samples in [0, 1) fall below `sparsity` with exactly that
  # probability; `.float()` converts the resulting mask directly to 0.0 / 1.0.
  return (torch.rand(size) < sparsity).float()

# One independently drawn sparse embedding per modality-1 symbol, generated
# in the order A, B, C, D, E.
MODALITY_1_SYMBOLS = {
    symbol: make_sparse_embedding(SENSORY_MODALITY_1_SIZE)
    for symbol in ('A', 'B', 'C', 'D', 'E')
}

class Network:
  """Ensemble of sensory and latent neurons with afferent and lateral weights.

  The first `sensory_neurons` entries of `activation` belong to the sensory
  neurons; the remaining `latent_neurons` entries belong to the latent neurons.

  Attributes:
    sensory_neurons: Number of neurons driven by the sensory input.
    activation: 1-D tensor with one value per ensemble neuron, initially zero.
    afferent_weights: (sensory_neurons, input_size) tensor of input weights,
      drawn uniformly and shifted down by `negative_weight`.
    lateral_weights: (ensemble, ensemble) tensor of neuron-to-neuron weights,
      drawn the same way, with a zero diagonal (no self-connections).
    afferent_trace: Running sum of scaled weighted inputs, same shape as
      `afferent_weights`.
    lateral_trace: Zero tensor with the shape of `lateral_weights`; it is
      allocated here but not updated by `present_input`.
    trace_alpha: Scale applied to each contribution added to a trace.
    activation_alpha: Mixing rate of the activation moving average.
  """

  def __init__(self, *, sensory_neurons=None, latent_neurons=None,
               input_size=None, negative_weight=None):
    """Create a network with random weights and zero activation and traces.

    Args:
      sensory_neurons: Number of sensory neurons. Defaults to
        `SENSORY_MODALITY_1_NEURONS`.
      latent_neurons: Number of latent neurons. Defaults to `LATENT_NEURONS`.
      input_size: Length of one sensory input vector. Defaults to
        `SENSORY_MODALITY_1_SIZE`.
      negative_weight: Offset subtracted from the uniform [0, 1) initial
        weights. Defaults to `NEGATIVE_WEIGHT`.
    """
    # Defaults are resolved at call time so they always follow the module
    # constants.
    if sensory_neurons is None:
      sensory_neurons = SENSORY_MODALITY_1_NEURONS
    if latent_neurons is None:
      latent_neurons = LATENT_NEURONS
    if input_size is None:
      input_size = SENSORY_MODALITY_1_SIZE
    if negative_weight is None:
      negative_weight = NEGATIVE_WEIGHT

    ensemble_neurons = sensory_neurons + latent_neurons
    self.sensory_neurons = sensory_neurons
    self.activation = torch.zeros(ensemble_neurons)

    # Weights lie in [-negative_weight, 1 - negative_weight), e.g. [-0.2, 0.8).
    self.afferent_weights = torch.rand((sensory_neurons, input_size)) - negative_weight
    self.lateral_weights = torch.rand((ensemble_neurons, ensemble_neurons)) - negative_weight
    # No neuron connects to itself. `fill_diagonal_` replaces indexing with a
    # uint8 mask, which PyTorch has deprecated in favour of bool masks.
    self.lateral_weights.fill_diagonal_(0)

    self.afferent_trace = torch.zeros(self.afferent_weights.shape)
    self.lateral_trace = torch.zeros(self.lateral_weights.shape)
    self.trace_alpha = 0.1
    self.activation_alpha = 0.01

  def present_input(self, sensory_input):
    """Apply one sensory input vector to the sensory neurons.

    Adds the scaled weighted input to `afferent_trace` and moves each sensory
    neuron's activation towards its summed weighted input by
    `activation_alpha`. Latent activations are left unchanged.

    Args:
      sensory_input: Tensor that broadcasts against `afferent_weights`
        (normally a 1-D tensor of length `input_size`).
    """
    weighted_input = self.afferent_weights * sensory_input
    self.afferent_trace += weighted_input * self.trace_alpha

    count = self.sensory_neurons
    retained = self.activation[:count] * (1 - self.activation_alpha)
    drive = weighted_input.sum(dim=1) * self.activation_alpha
    self.activation[:count] = retained + drive

# Repeatedly present symbol 'A' to a fresh network and print the activation
# vector before the first presentation and after each one.
network = Network()
symbol_a = MODALITY_1_SYMBOLS['A']
print("activation", network.activation)
for _step in range(100):
  network.present_input(symbol_a)
  # Debug aid: print network.afferent_trace here to inspect the afferent trace.
  print("activation", network.activation)

activation tensor([0., 0., 0., 0., 0.])
activation tensor([0.0340, 0.0300, 0.0000, 0.0000, 0.0000])
activation tensor([0.0676, 0.0597, 0.0000, 0.0000, 0.0000])
activation tensor([0.1009, 0.0890, 0.0000, 0.0000, 0.0000])
activation tensor([0.1339, 0.1181, 0.0000, 0.0000, 0.0000])
activation tensor([0.1665, 0.1469, 0.0000, 0.0000, 0.0000])
activation tensor([0.1989, 0.1754, 0.0000, 0.0000, 0.0000])
activation tensor([0.2309, 0.2037, 0.0000, 0.0000, 0.0000])
activation tensor([0.2625, 0.2316, 0.0000, 0.0000, 0.0000])
activation tensor([0.2939, 0.2593, 0.0000, 0.0000, 0.0000])
activation tensor([0.3249, 0.2867, 0.0000, 0.0000, 0.0000])
activation tensor([0.3557, 0.3138, 0.0000, 0.0000, 0.0000])
activation tensor([0.3861, 0.3406, 0.0000, 0.0000, 0.0000])
activation tensor([0.4162, 0.3672, 0.0000, 0.0000, 0.0000])
activation tensor([0.4460, 0.3935, 0.0000, 0.0000, 0.0000])
activation tensor([0.4756, 0.4195, 0.0000, 0.0000, 0.0000])
activation tensor([0.5048, 0.4453, 0.0000, 0.0000, 0.0000])
