# Aggregate-Label Learning (Draft)

A reproduction of the aggregate-label learning task using an implementation of the multi-spike tempotron in Julia. 

For further details see [Gütig, R. (2016). Spiking neurons can discover predictive features by aggregate-label learning. Science, 351(6277), aab4113.](https://science.sciencemag.org/content/351/6277/aab4113)

## Imports

In [1]:
using Tempotrons
using Tempotrons.InputGen
using Tempotrons.Plots
using Tempotrons.Optimizers
using ProgressMeter
using Plots
using Plots.PlotMeasures;

## Set parameters

In [2]:
N                = 500                       # number of afferent neurons
T                = 1000                      # base input duration (in ms)
dt               = 1                         # time precision (in ms, for visualization purposes)
ν                = 5                         # background/event firing frequency (in Hz)
λ                = 0.001                     # learning rate
opt              = SGD(λ, momentum = 0.99)   # optimizer
Nᶠ               = 10                        # number of event types
Tᶠ               = 50                        # event duration (in ms)
Cᶠ_mean          = 0.1                       # mean number of event occurrences per sample per event type
n_steps          = 20000                     # number of training steps
n_train_samples  = 5000                      # number of training samples
n_test_samples   = 10                        # number of test samples
tmp              = Tempotron(N = N, τₘ = 20) # the multi-spike tempotron
Pretrain!(tmp);                              # initialize the tempotron's weights by a pretraining process (see original paper for details)
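As a rough sanity check on these settings, the expected amount of input per sample can be worked out by hand. The sketch below assumes two things. First, every afferent fires as a Poisson process at rate `ν`, both in the background and inside events. Second, `Cᶠ_mean` is the mean number of occurrences of each event type. This notebook does not show how `GenerateSampleWithEmbeddedEvents` actually combines background and events.

```julia
# Parameters copied from the cell above
N, T, ν = 500, 1000, 5            # afferents, base duration (ms), firing rate (Hz)
Nᶠ, Tᶠ, Cᶠ_mean = 10, 50, 0.1     # event types, event duration (ms), mean occurrences per type

# Expected number of embedded events per sample, summed over all event types
expected_events = Nᶠ * Cᶠ_mean

# Expected spike counts, assuming Poisson firing at ν Hz on every afferent
background_spikes = N * ν * T / 1000      # spikes over the base duration
spikes_per_event  = N * ν * Tᶠ / 1000     # spikes inside a single event

println("events/sample ≈ ", expected_events,
        ", background spikes ≈ ", background_spikes,
        ", spikes/event ≈ ", spikes_per_event)
```

Under these assumptions, a sample contains about one event on average. A single event carries roughly 125 of the ~2,600 input spikes, which is on the order of 5%.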

## Set teacher's rule
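The numbered rules follow the examples of [Fig. 1B](https://science.sciencemag.org/content/351/6277/aab4113#F2) in the [original paper](https://science.sciencemag.org/content/351/6277/aab4113). Each rule maps the list of event types embedded in a sample to that sample's target number of output spikes.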

In [3]:
# y₀(event_types)::Integer = isempty(event_types) ? 0 : length(filter(x -> x == 2, event_types));   #1: one spike per type-2 event
# y₀(event_types)::Integer = isempty(event_types) ? 0 : 5*length(filter(x -> x == 2, event_types)); #2: five spikes per type-2 event
# y₀(event_types)::Integer = isempty(event_types) ? 0 : length(filter(x -> x%2 == 0, event_types)); #3: one spike per even-typed event
y₀(event_types)::Integer = isempty(event_types) ? 0 : sum(filter(x -> x%2 == 0, event_types)/2);  #4: event type 2k adds k spikes
# y₀(event_types)::Integer = isempty(event_types) ? 0 : sum(filter(x -> x%5 == 0, event_types)/5);  # event type 5k adds k spikes
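To make the rules concrete, the sketch below restates rules #1 to #4 and the unnumbered fifth variant as separate functions, then applies them to a made-up list of event types. The names `rule1` to `rule5` and the example input are illustrative only. The sketch uses integer division (`÷`), so every rule returns an integer directly. Passing `init` to `sum` requires Julia 1.6 or later.

```julia
# Hypothetical stand-alone versions of the teacher rules above (integer arithmetic throughout)
rule1(types) = count(==(2), types)                                   # one spike per type-2 event
rule2(types) = 5 * count(==(2), types)                               # five spikes per type-2 event
rule3(types) = count(iseven, types)                                  # one spike per even-typed event
rule4(types) = sum((x ÷ 2 for x in types if iseven(x)); init = 0)    # type 2k contributes k spikes
rule5(types) = sum((x ÷ 5 for x in types if x % 5 == 0); init = 0)   # type 5k contributes k spikes

example = [2, 4, 7, 2]   # made-up sequence of embedded event types
for (name, rule) in [("#1", rule1), ("#2", rule2), ("#3", rule3), ("#4", rule4), ("#5", rule5)]
    println(name, " → ", rule(example))
end
```

In `example`, the odd type 7 is ignored by rules #3 and #4. None of the types is a multiple of 5, so the fifth variant yields no spikes.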

## Generate samples

### Input events

In [4]:
events = GetEvents(Nᶠ = Nᶠ, Tᶠ = Tᶠ, N = N, ν = ν);
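`GetEvents` comes from `Tempotrons.InputGen`, and this notebook does not show its exact sampling procedure. This sketch assumes a common choice: each event is a frozen spatiotemporal pattern in which every afferent fires as a homogeneous Poisson process at rate `ν` for `Tᶠ` ms. The helpers `poisson_train` and `make_events` are hypothetical and are not part of the package.

```julia
using Random

# Homogeneous Poisson spike train on [0, duration) ms at `rate` Hz (rate > 0),
# built from exponentially distributed inter-spike intervals.
function poisson_train(rng::AbstractRNG, rate, duration)
    mean_isi = 1000 / rate                 # mean inter-spike interval in ms
    spikes = Float64[]
    t = mean_isi * randexp(rng)
    while t < duration
        push!(spikes, t)
        t += mean_isi * randexp(rng)
    end
    return spikes
end

# Hypothetical: Nᶠ frozen event patterns, each a vector of N spike trains spanning Tᶠ ms.
make_events(rng::AbstractRNG; Nᶠ, Tᶠ, N, ν) =
    [[poisson_train(rng, ν, Tᶠ) for _ in 1:N] for _ in 1:Nᶠ]

rng = MersenneTwister(1)
toy_events = make_events(rng; Nᶠ = 10, Tᶠ = 50, N = 500, ν = 5)
mean_spikes = sum(sum(length, e) for e in toy_events) / length(toy_events)
println("mean spikes per event ≈ ", mean_spikes)   # expected value: N * ν * Tᶠ / 1000 = 125
```

Each pattern is drawn once and then reused unchanged every time that event type appears. This reuse makes the patterns learnable features rather than noise.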

### Test samples

In [5]:
test_samples = [GenerateSampleWithEmbeddedEvents(events, Tᶠ = Tᶠ, Cᶠ_mean = Cᶠ_mean, ν = ν, T = T)
                for j = 1:n_test_samples]
test_samples = [(ts..., 
                 y = y₀(ts.event_types), 
                 t = collect(0:dt:maximum(abs, maximum.(abs, ts.x))))
                for ts ∈ test_samples];
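The sketch below shows one plausible way to embed events into background activity, stated here as an assumption rather than as what `GenerateSampleWithEmbeddedEvents` does:

1. Draw a Poisson-distributed number of occurrences for each event type.
2. Insert each occurrence at a random point of a background of length `T`.
3. Delay every later background spike by `Tᶠ`.

The helpers `poisson_count` and `embed_events`, and the returned `duration` field, are hypothetical. `poisson_train` is redefined so the block runs on its own.

```julia
using Random

# Poisson-distributed count with mean μ (Knuth's multiplication method; fine for small μ).
function poisson_count(rng::AbstractRNG, μ)
    L, k, p = exp(-μ), 0, rand(rng)
    while p > L
        k += 1
        p *= rand(rng)
    end
    return k
end

# Homogeneous Poisson spike train on [0, duration) ms at `rate` Hz (rate > 0).
function poisson_train(rng::AbstractRNG, rate, duration)
    mean_isi, spikes = 1000 / rate, Float64[]
    t = mean_isi * randexp(rng)
    while t < duration
        push!(spikes, t)
        t += mean_isi * randexp(rng)
    end
    return spikes
end

# Hypothetical embedding scheme: insert each event occurrence at a random point of a
# background of length T and push all later background spikes back by Tᶠ.
function embed_events(rng::AbstractRNG, events; Cᶠ_mean, ν, T, Tᶠ)
    N = length(first(events))
    types = Int[]                                      # event types in order of appearance
    for k in eachindex(events)
        append!(types, fill(k, poisson_count(rng, Cᶠ_mean)))
    end
    shuffle!(rng, types)
    positions = sort(T .* rand(rng, length(types)))    # insertion points in background time
    onsets = Float64[positions[j] + (j - 1) * Tᶠ for j in eachindex(positions)]
    x = map(1:N) do n
        # each background spike is delayed by Tᶠ for every event inserted before it
        shifted = Float64[s + Tᶠ * count(<=(s), positions) for s in poisson_train(rng, ν, T)]
        inserted = Float64[onsets[j] + t for j in eachindex(types) for t in events[types[j]][n]]
        sort!(vcat(shifted, inserted))
    end
    return (x = x, event_times = onsets, event_types = types, duration = T + Tᶠ * length(types))
end

rng = MersenneTwister(2)
toy_patterns = [[poisson_train(rng, 5, 50) for _ in 1:20] for _ in 1:3]   # 3 types, 20 afferents
toy_sample = embed_events(rng, toy_patterns; Cᶠ_mean = 1.0, ν = 5, T = 1000, Tᶠ = 50)
println(toy_sample.event_types, " at ", round.(toy_sample.event_times; digits = 1), " ms")
```

Under this scheme, the sample grows by `Tᶠ` for every embedded event, and no background spike falls inside an event window.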

In [6]:
# Prepare to plot
plotlyjs(size = (800, 1500))
cols = collect(1:Nᶠ)  # alternative: palette(:rainbow, Nᶠ)
test_events = [[(time = s.event_times[i], length = Tᶠ, color = cols[s.event_types[i]])
                for i = 1:length(s.event_times)] 
               for s ∈ test_samples]

# Plots
inp_plots = [PlotInputs(ReduceAfferents(test_samples[i].x, 0.1), events = test_events[i])
             for i = 1:length(test_samples)]
plot(inp_plots..., layout = (length(inp_plots), 1), link = :all, bottom_margin = 12mm)

### Training samples

In [7]:
train_samples = @showprogress 1 "Generating samples..." [GenerateSampleWithEmbeddedEvents(events, Tᶠ = Tᶠ, 
                                                                                          Cᶠ_mean = Cᶠ_mean, 
                                                                                          ν = ν, T = T)
                                                         for j = 1:n_train_samples]
train_samples = [(x = ts.x, y = y₀(ts.event_types))
                 for ts ∈ train_samples];

Generating samples...100%|██████████████████████████████| Time: 0:00:11
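The teacher signal is only a spike count, so it can help to check how the labels are distributed before training. Suppose, as an assumption, that event occurrences are Poisson-distributed. The five even-numbered types then contribute 5 × 0.1 = 0.5 expected events per sample, so under rule #4 roughly e⁻⁰·⁵ ≈ 61% of samples would carry label 0. A minimal helper for tallying labels (the name `label_counts` is hypothetical):

```julia
# Tally how often each label value occurs (hypothetical helper, Base only).
function label_counts(labels)
    counts = Dict{Int,Int}()
    for y in labels
        counts[y] = get(counts, y, 0) + 1
    end
    return sort!(collect(counts))    # label => count pairs, sorted by label
end

# In the notebook this could be called as `label_counts(s.y for s ∈ train_samples)`;
# here it is applied to a small made-up list of labels.
println(label_counts([0, 1, 0, 2, 0, 1]))
```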


## Train
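Training uses the optimizer defined in the parameter cell, `SGD(λ, momentum = 0.99)`. The sketch below shows the textbook heavy-ball momentum update on a plain weight vector, to illustrate what the two parameters control. It is an assumption that `Tempotrons.Optimizers` uses exactly this form; for example, its sign conventions or the placement of `λ` may differ. `MomentumSGD` and `step!` are hypothetical names.

```julia
# Textbook heavy-ball momentum (hypothetical stand-alone version):
#   v ← μ·v − λ·g,   w ← w + v
struct MomentumSGD
    λ::Float64            # learning rate
    μ::Float64            # momentum coefficient
    v::Vector{Float64}    # velocity, one entry per weight (mutated in place)
end
MomentumSGD(λ, n::Integer; momentum = 0.0) = MomentumSGD(λ, momentum, zeros(n))

# Apply one update to the weights `w`, given the gradient `g` of the loss.
function step!(opt::MomentumSGD, w::Vector{Float64}, g::Vector{Float64})
    opt.v .= opt.μ .* opt.v .- opt.λ .* g
    w .+= opt.v
    return w
end

# Toy usage: minimize f(w) = ½‖w‖², whose gradient is w itself.
w = [1.0, -2.0]
toy_opt = MomentumSGD(0.01, length(w); momentum = 0.9)
for _ in 1:500
    step!(toy_opt, w, copy(w))
end
println(w)   # both entries end up very close to zero
```

In this form, a constant gradient drives the velocity toward −λg/(1−μ). With μ = 0.99 the effective step is therefore about 100 times the learning rate. Separately, the loop below draws `n_steps = 20000` samples with replacement from `n_train_samples = 5000`. Each training sample is therefore used about 4 times on average, and roughly e⁻⁴ ≈ 2% of them are never drawn.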

In [None]:
# Train the tempotron
@showprogress 1 "Training..." for i = 1:n_steps
    s = rand(train_samples)
    Train!(tmp, s.x, s.y, optimizer = opt)
end

Training... 67%|███████████████████████████             |  ETA: 0:13:45

## Plots

### Get voltage traces for the test samples

In [None]:
out_a = @showprogress 1 "Evaluating test samples..." [tmp(s.x, t = s.t) for s ∈ test_samples];
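Calling `tmp(s.x, t = s.t)` yields a result whose `V` and `spikes` fields are used in the plots below. To illustrate what such a trace contains, the sketch below evaluates the standard multi-spike tempotron potential on a time grid:

- Each input spike adds a weighted difference-of-exponentials kernel with time constants τₘ and τₛ.
- Each output spike subtracts a threshold-scaled reset that decays with τₘ.

The values τₛ = τₘ/4 and θ = 1, and the peak normalization of the kernel, are assumptions made for the sketch. `psp_kernel` and `membrane_trace` are hypothetical names. The package's own implementation may differ; for example, it need not use a fixed grid.

```julia
# Difference-of-exponentials PSP kernel, scaled so that its peak equals 1.
function psp_kernel(τₘ, τₛ)
    t_peak = τₘ * τₛ / (τₘ - τₛ) * log(τₘ / τₛ)      # time of the kernel maximum
    V₀ = 1 / (exp(-t_peak / τₘ) - exp(-t_peak / τₛ))
    return s -> s > 0 ? V₀ * (exp(-s / τₘ) - exp(-s / τₛ)) : 0.0
end

# Membrane potential of a multi-spike tempotron on the time grid `ts` (ms).
# x[i] holds the input spike times of afferent i, and w[i] is its synaptic weight.
# Whenever V reaches θ an output spike is recorded; from then on a reset term
# −θ·exp(−(t − tₛ)/τₘ) is subtracted from the potential.
function membrane_trace(x, w, ts; τₘ = 20.0, τₛ = τₘ / 4, θ = 1.0)
    K = psp_kernel(τₘ, τₛ)
    V = zeros(length(ts))
    spikes = Float64[]
    for (k, t) in enumerate(ts)
        v = 0.0
        for i in eachindex(x), tᵢ in x[i]
            v += w[i] * K(t - tᵢ)                  # synaptic input
        end
        for tₛ in spikes
            v -= θ * exp(-(t - tₛ) / τₘ)           # resets from earlier output spikes
        end
        if v >= θ
            push!(spikes, t)
            v -= θ                                 # reset of the new spike, exp(0) = 1
        end
        V[k] = v
    end
    return (V = V, spikes = spikes)
end

ts = 0:0.1:100              # time grid in ms
toy_x = [[10.0]]            # a single afferent with one input spike at 10 ms
weak = membrane_trace(toy_x, [0.5], ts)
strong = membrane_trace(toy_x, [2.0], ts)
println("weak: peak V ≈ ", round(maximum(weak.V); digits = 3), ", output spikes: ", length(weak.spikes))
println("strong: first output spike at ", round(first(strong.spikes); digits = 1), " ms")
```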

### Prepare to plot

In [None]:
plotlyjs(size = (800, 1500))
cols = collect(1:Nᶠ)  # alternative: palette(:rainbow, Nᶠ)
test_events = [[(time = s.event_times[i], length = Tᶠ, color = cols[s.event_types[i]])
                for i = 1:length(s.event_times)] 
               for s ∈ test_samples];

### Plot

In [None]:
inp_plots = [PlotInputs(ReduceAfferents(test_samples[i].x, 0.1), events = test_events[i])
             for i = 1:length(test_samples)]
train_plots = [PlotPotential(tmp, out = out_a[i].V, N = length(out_a[i].spikes), 
                             N_t = test_samples[i].y,
                             t = test_samples[i].t, events = test_events[i])
               for i = 1:length(test_samples)]
# Left column: input spike rasters; right column: membrane potentials of the trained tempotron
ps = [inp_plots; train_plots]
l = @layout [grid(length(inp_plots), 1) grid(length(inp_plots), 1)]
plot(ps..., layout = l, link = :all, left_margin = 8mm, bottom_margin = 12mm)
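Each potential plot compares the number of output spikes, `length(out_a[i].spikes)`, with the teacher's target, `test_samples[i].y`. For a quick numeric summary alongside the figure, a small helper such as the hypothetical `count_summary` below reports two numbers: the fraction of samples with exactly the right spike count, and the mean absolute count error.

```julia
# Compare output spike counts with target counts (hypothetical helper).
function count_summary(n_out::AbstractVector{<:Integer}, n_target::AbstractVector{<:Integer})
    length(n_out) == length(n_target) ||
        throw(DimensionMismatch("count vectors have different lengths"))
    errors = abs.(n_out .- n_target)
    return (exact = count(iszero, errors) / length(errors),      # fraction of exactly correct counts
            mean_abs_error = sum(errors) / length(errors))       # average spike-count error
end

# In the notebook this could be called as
#   count_summary([length(o.spikes) for o ∈ out_a], [s.y for s ∈ test_samples])
# here it is applied to made-up counts.
println(count_summary([0, 2, 1, 3], [0, 2, 2, 1]))
```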

In [None]:
savefig("AggLabels.png");