# Session 6: Goal-Driven Deep Learning

## About this tutorial
In the lecture you have seen how deep learning can be used to model brain regions and functions. This is motivated by the fact that deep neural networks (DNNs) combine brain-inspired computational principles with hitherto unseen effectiveness at solving perceptual and motor tasks. This makes DNNs well suited to uncovering the kinds of representations and computations that may underlie complex, high-level functions of biological systems. As such, biologically plausible DNNs can be used as generative models to formulate new hypotheses about brain function. Furthermore, DNNs can be used to test hypotheses in neuroscience in silico by training them on ecologically relevant tasks and subsequently exposing them to stimuli used in neuroscientific experiments. However, while DNNs are biologically inspired, they are not yet particularly biologically realistic.

In this tutorial we will address one shortcoming of DNNs: their activation functions are not particularly biologically plausible. Real neurons communicate with discrete spikes, whereas DNN units have continuous activation functions. Given what you learned last week, it should nevertheless be possible to train a DNN with a spiking-neuron activation function, and this is what you will do in this tutorial. This tutorial constitutes your third formative assessment, and you have until 6pm on October 11 to finish it. A solution will become available after that deadline.

## Spiking Activation Function
To make a deep neural network spiking, we will use an approach put forth by [Hunsberger & Eliasmith (2015)](https://arxiv.org/abs/1510.08829). These authors used the steady-state firing rate of leaky integrate-and-fire (LIF) neurons as the activation function in their neural networks. The steady-state firing rate of a LIF neuron can be derived analytically. We start with a simplified neuron model (reversal potential set to $0$ and conductance set to $1$):

$$
\tau_m \dot{V} = -V + I
$$

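where $\tau_m$ is the membrane time constant and $I$ is a constant input current. Whenever $V$ reaches the firing threshold, the neuron emits a spike, $V$ is reset to $0$, and the neuron remains silent for a refractory period. For this model, the steady-state firing rate is given by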

$$
G \left( I \right) = \begin{cases}
        \left[ \tau_{ref} - \tau_m \ln \left( 1 - \frac{V_{thr}}{I} \right) \right]^{-1}, & \text{if } I > V_{thr}\\
        0 & \text{otherwise}
        \end{cases}
$$

where $\tau_{ref}$ is the refractory period and $V_{thr}$ is the firing threshold.
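For completeness, here is where this expression comes from, assuming that $V$ starts from the reset value $V(0) = 0$ once the refractory period is over. Solving the membrane equation for constant $I$ and setting $V(t_{spk}) = V_{thr}$ gives the time $t_{spk}$ needed to reach threshold; one interspike interval then lasts $\tau_{ref} + t_{spk}$:

$$
\begin{aligned}
V(t) &= I \left( 1 - e^{-t/\tau_m} \right) \\
V(t_{spk}) = V_{thr} \;\Rightarrow\; t_{spk} &= -\tau_m \ln \left( 1 - \frac{V_{thr}}{I} \right) \\
G(I) &= \frac{1}{\tau_{ref} + t_{spk}} = \left[ \tau_{ref} - \tau_m \ln \left( 1 - \frac{V_{thr}}{I} \right) \right]^{-1}
\end{aligned}
$$

Because $V(t)$ only approaches $I$ asymptotically, the threshold is reached only if $I > V_{thr}$, which is why $G(I) = 0$ otherwise.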

The LIF steady-state firing rate has the particular problem that its derivative approaches infinity as $I$ approaches $V_{thr}$ from above. This causes problems for backpropagation, but it can be addressed by slightly adjusting the firing-rate equation to smooth it out. First, the equation above can be rewritten as:

$$
G \left( I \right) = \left[ \tau_{ref} + \tau_m \ln \left( 1 + \frac{V_{thr}}{\rho \left( I - V_{thr} \right)} \right) \right]^{-1}
$$
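The two forms agree because, for $I > V_{thr}$, a short piece of algebra gives

$$
-\ln \left( 1 - \frac{V_{thr}}{I} \right) = \ln \left( \frac{I}{I - V_{thr}} \right) = \ln \left( 1 + \frac{V_{thr}}{I - V_{thr}} \right),
$$

while for $I \le V_{thr}$ we have $\rho \left( I - V_{thr} \right) = 0$, so the logarithm diverges and $G(I)$ goes to $0$, reproducing the "otherwise" branch.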

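where $\rho (x) = \max(x,0)$. If we replace this hard maximum with a softer maximum, the softplus function $\rho (x) = \ln\left( 1 + e^x \right)$, then the LIF neuron loses its hard threshold and the derivative remains finite. More generally, Hunsberger & Eliasmith (2015) use $\rho (x) = \gamma \ln\left( 1 + e^{x/\gamma} \right)$, where the smoothing parameter $\gamma$ controls how closely $\rho$ approximates $\max(x,0)$ (smaller $\gamma$ gives a sharper threshold; $\gamma = 1$ recovers the softplus above). This will be the LIF-based activation function for our spiking DNN.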
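The sketch below compares the two forms numerically and checks the claim about the derivative with autograd. It is a minimal illustration, not the solution to Task 1: the helper names `rate_piecewise`, `rate_hard`, `rate_soft` and `grad_at` are hypothetical, the parameter values are taken from Task 1, and the soft maximum is computed with PyTorch's `torch.nn.functional.softplus`, which evaluates $\gamma \ln(1 + e^{x/\gamma})$ when called with `beta=1/gamma`. Double precision is used only so that inputs very close to the threshold can be resolved.

```python
import torch
import torch.nn.functional as F

# Parameter values from Task 1 (tau_m and tau_ref in seconds)
TAU_M, TAU_REF, V_THR = 0.02, 0.004, 1.0


def rate_piecewise(i):
    """Original piecewise LIF rate: [tau_ref - tau_m ln(1 - V_thr / I)]^-1 if I > V_thr, else 0."""
    above = i > V_THR
    safe_i = torch.where(above, i, torch.full_like(i, 2 * V_THR))  # avoid log of values <= 0
    rate = 1.0 / (TAU_REF - TAU_M * torch.log(1 - V_THR / safe_i))
    return torch.where(above, rate, torch.zeros_like(i))


def rate_hard(i):
    """Rewritten form with rho(x) = max(x, 0); V_thr / 0 = inf, so the rate becomes 0."""
    rho = torch.clamp(i - V_THR, min=0.0)
    return 1.0 / (TAU_REF + TAU_M * torch.log1p(V_THR / rho))


def rate_soft(i, gamma=1.0):
    """Soft LIF: rho(x) = gamma * ln(1 + exp(x / gamma)); gamma = 1 is plain softplus."""
    rho = F.softplus(i - V_THR, beta=1.0 / gamma)
    return 1.0 / (TAU_REF + TAU_M * torch.log1p(V_THR / rho))


i = torch.tensor([0.5, 1.0, 1.5, 2.0, 3.0], dtype=torch.float64)
print(torch.allclose(rate_piecewise(i), rate_hard(i)))  # True: both forms agree


def grad_at(rate_fn, current):
    """dG/dI at a single input current, computed with autograd."""
    x = torch.tensor([current], dtype=torch.float64, requires_grad=True)
    rate_fn(x).sum().backward()
    return x.grad.item()


near = V_THR + 1e-8               # just above threshold
print(grad_at(rate_hard, near))   # very large: diverges as I -> V_thr from above
print(grad_at(rate_soft, near))   # moderate and finite
```

With $\gamma = 1$ the soft version still fires at a noticeable rate well below threshold, which is one reason to choose a smaller $\gamma$ in practice.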

## Deep Learning with PyTorch

Before you can start building and training a spiking DNN, you should first have a look at how to build and train any neural network using machine learning libraries like [PyTorch](https://pytorch.org/) or [TensorFlow](https://www.tensorflow.org/). In the present session, we will use the former. 

In this section, we run through the API for common tasks in deep learning. 

### Working with data

PyTorch has two primitives to work with data: `torch.utils.data.DataLoader` and `torch.utils.data.Dataset`. `Dataset` stores the samples and their corresponding labels, and `DataLoader` wraps an iterable around the `Dataset`.
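To make the two primitives concrete, here is a minimal map-style `Dataset` built from random tensors. `ToyDataset` is a hypothetical class used only for illustration: a map-style dataset only needs `__len__` and `__getitem__`, and `DataLoader` takes care of collating individual samples into batches.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyDataset(Dataset):
    """Hypothetical dataset: n random 1x28x28 'images' with integer labels."""

    def __init__(self, n_samples=10, n_classes=10, seed=0):
        g = torch.Generator().manual_seed(seed)
        self.images = torch.rand(n_samples, 1, 28, 28, generator=g)
        self.labels = torch.randint(0, n_classes, (n_samples,), generator=g)

    def __len__(self):
        return len(self.labels)            # number of samples

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]   # one (sample, label) pair


loader = DataLoader(ToyDataset(), batch_size=4)
X, y = next(iter(loader))                  # first batch
print(X.shape, y.shape)                    # torch.Size([4, 1, 28, 28]) torch.Size([4])
print(len(loader))                         # 3 batches: 4 + 4 + 2 samples
```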

```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
```

PyTorch offers domain-specific libraries which include datasets. For this tutorial, we will be using the `TorchVision` dataset library.

The `torchvision.datasets` module contains `Dataset` objects for many real-world vision data like CIFAR. Here, we use the FashionMNIST dataset. Every TorchVision `Dataset` includes two arguments: `transform` and `target_transform` to modify the samples and labels respectively.

```python
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
```

We pass the `Dataset` as an argument to `DataLoader`. This wraps an iterable over our dataset, and supports automatic batching, sampling, shuffling and multiprocess data loading. Here we define a batch size of $64$, i.e. each element in the dataloader iterable will return a batch of $64$ features and labels.
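A quick check of the batching arithmetic: with $60000$ training images and a batch size of $64$, the loader yields $\lceil 60000/64 \rceil = 938$ batches, the last of which holds only $60000 - 937 \cdot 64 = 32$ samples. The sketch uses a dummy `TensorDataset` of the same length instead of FashionMNIST. It also shows two `DataLoader` options that the cells here do not use: `shuffle=True`, which draws a new random order every epoch and is commonly used for training data, and `drop_last=True`, which discards the incomplete final batch.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

n_samples, batch_size = 60000, 64   # FashionMNIST training-set size, batch size above
dummy = TensorDataset(torch.zeros(n_samples, 1), torch.zeros(n_samples, dtype=torch.long))

loader = DataLoader(dummy, batch_size=batch_size)                   # as in the cell below
shuffled = DataLoader(dummy, batch_size=batch_size, shuffle=True)   # new random order every epoch
trimmed = DataLoader(dummy, batch_size=batch_size, drop_last=True)  # drops the incomplete last batch

print(len(loader), math.ceil(n_samples / batch_size))  # 938 938
print(len(shuffled), len(trimmed))                     # 938 937

for last_X, last_y in loader:   # iterate through to the final batch
    pass
print(last_X.shape)             # torch.Size([32, 1])
```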

```python
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
```

```
Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
```


### Creating Models

To define a neural network in PyTorch, we create a class that inherits from `nn.Module`. We define the layers of the network in the `__init__` method and specify how data passes through the network in the `forward` method.

```python
# Set cpu device for training.
device = "cpu"
print(f"Using {device} device")

# Define model
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv_stack1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(7, 7), padding=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        self.conv_stack2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(7, 7), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        self.linear_stack = nn.Sequential(
            nn.Flatten(),
            nn.Linear(800, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        conv1 = self.conv_stack1(x)
        conv2 = self.conv_stack2(conv1)
        logits = self.linear_stack(conv2)
        return logits

model = ConvNet().to(device)
print(model)
```

```
Using cpu device
ConvNet(
  (conv_stack1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (conv_stack2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(7, 7), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (linear_stack): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=800, out_features=512, bias=True)
    (2): ReLU()
    (3): Linear(in_features=512, out_features=10, bias=True)
  )
)
```


Feel free to play around with the number of channels, filter sizes etc. Make sure, however, that things fit together. For instance, the number of inputs to the first *fully connected* (`Linear`) layer was set to $800$ because $32$ channels $\times$ $5$ pixels $\times$ $5$ pixels $=800$. How did we get to $5 \times 5$ pixels? Note that we have two steps where we have convolution followed by max pooling. There is a neat equation to find the image dimensions after a convolution or pooling operation:

$$
\begin{array}{lr}
W_o &= \left\lfloor \frac{W_i - F_W + 2 P_W}{S_W} \right\rfloor + 1 \\
H_o &= \left\lfloor \frac{H_i - F_H + 2 P_H}{S_H} \right\rfloor + 1
\end{array}
$$

where $W$ and $H$ stand for width and height, respectively; the subscripts $i$ and $o$ reflect input and output dimensions, respectively; $F$ is the filter (kernel) size; $P$ is the zero-padding; $S$ is the stride; and $\lfloor \cdot \rfloor$ denotes the `floor` operation (rounding down). In our example, we are dealing with square images (i.e., $W=H$).
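The formula is easy to turn into a small helper for checking layer sizes before building a network. `conv_out` is a hypothetical helper name; it assumes a dilation of 1 (the PyTorch default) and applies equally to convolution and pooling layers.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Output width (or height) of a convolution or pooling layer, assuming dilation 1."""
    # floor((W_i - F + 2P) / S) + 1; // is floor division for these non-negative values
    return (size - kernel + 2 * padding) // stride + 1


print(conv_out(28, kernel=7, padding=3))  # 28: this padding keeps the size unchanged
print(conv_out(28, kernel=2, stride=2))   # 14: 2x2 max pooling with stride 2 halves the size
print(conv_out(15, kernel=2, stride=2))   # 7: the floor drops the last odd row/column
```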

Let's trace the sizes of our layers through the two convolution stacks. 

#### Conv Stack 1
We start with a $28 \times 28$ gray-scale image, so the input dimensions are $28 \times 28 \times 1$. Then we apply convolutions with $16$ different square filters, each of size $F=7$, with zero-padding $P=3$ and a stride of $S=1$. Since $\lfloor (28 - 7 + 2 \cdot 3)/1 \rfloor + 1 = 28$, the first convolutional layer has dimensions $28 \times 28 \times 16$.

Then, we apply a pooling operation with $F=2$ (`kernel_size`), $P=0$ and $S=2$. Since $\lfloor (28 - 2)/2 \rfloor + 1 = 14$, the pooling layer has dimensions $14 \times 14 \times 16$.

#### Conv Stack 2
Next, we apply another convolution to the output of the first stack's pooling layer, using $32$ different square filters, each of size $F=7$, with zero-padding $P=1$ and a stride of $S=1$. Since $\lfloor (14 - 7 + 2 \cdot 1)/1 \rfloor + 1 = 10$, the second convolutional layer has dimensions $10 \times 10 \times 32$.

Finally, we apply another pooling operation with $F=2$ (`kernel_size`), $P=0$ and $S=2$. Since $\lfloor (10 - 2)/2 \rfloor + 1 = 5$, the second pooling layer has dimensions $5 \times 5 \times 32$.
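The same numbers can be confirmed empirically by pushing a dummy image through the layers and recording the shapes. This sketch rebuilds only the convolution and pooling layers of `ConvNet`; the activation functions are left out because they do not change the shape.

```python
import torch
from torch import nn

layers = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=7, padding=3),   # conv stack 1
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(16, 32, kernel_size=7, padding=1),  # conv stack 2
    nn.MaxPool2d(kernel_size=2, stride=2),
)

x = torch.zeros(1, 1, 28, 28)   # one dummy gray-scale image, shape [N, C, H, W]
shapes = []
for layer in layers:
    x = layer(x)
    shapes.append(tuple(x.shape))
    print(type(layer).__name__, shapes[-1])
print(x.flatten(start_dim=1).shape[1])  # 800 inputs for the first Linear layer
```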


### Optimizing the Model Parameters

To train a model, we need a loss function and an optimizer. Here we use stochastic gradient descent (`SGD`). It is not the fastest choice (adaptive optimizers such as `Adam` usually train this kind of network considerably faster), but it is useful for illustrating how the network gradually improves with training. Feel free to try other optimizers after you have finished the session.

```python
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
```
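To see what a single optimizer step does, the sketch below checks that one `torch.optim.SGD` step with default settings (no momentum, no weight decay) is exactly the update $w \leftarrow w - \eta \, \partial L / \partial w$, with learning rate $\eta$ = `lr`. The toy linear layer and random data are assumptions made only for this check. Swapping in another optimizer is a one-line change, e.g. `torch.optim.Adam(model.parameters(), lr=1e-3)`.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(3, 1)                                   # toy model
optimizer = torch.optim.SGD(layer.parameters(), lr=1e-3)  # same settings as above

X, target = torch.randn(8, 3), torch.randn(8, 1)          # random toy data
loss = nn.functional.mse_loss(layer(X), target)
optimizer.zero_grad()
loss.backward()

# Plain gradient-descent update, computed by hand before the optimizer step
expected = layer.weight.detach() - 1e-3 * layer.weight.grad
optimizer.step()
print(torch.allclose(layer.weight.detach(), expected))    # True
```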

In a single training loop, the model makes predictions on the training dataset (fed to it in batches), and backpropagates the prediction error to adjust the model’s parameters.

```python
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
```
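The call to `optimizer.zero_grad()` matters because PyTorch *accumulates* gradients in `.grad` across `backward()` calls instead of overwriting them. The toy scalar parameter `w` below is only for illustration:

```python
import torch

w = torch.tensor([2.0], requires_grad=True)

for _ in range(3):
    loss = (3 * w).sum()   # d(loss)/dw = 3
    loss.backward()        # adds 3 to w.grad each time
accumulated = w.grad.item()
print(accumulated)         # 9.0, not 3.0

# Reset the gradient; optimizer.zero_grad() does this for all model parameters
# (either zeroing .grad or setting it to None)
w.grad.zero_()
(3 * w).sum().backward()
fresh = w.grad.item()
print(fresh)               # 3.0
```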

We also check the model’s performance against the test dataset to ensure it is learning.

```python
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
```
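The accuracy line in `test` relies on `argmax` over the class dimension: the predicted class is the index of the largest logit, and comparing it with the labels gives one boolean per sample. A toy batch with made-up logits shows the arithmetic. Note also that `nn.CrossEntropyLoss` returns the *mean* loss of each batch by default, which is why `test` divides the summed loss by the number of batches, whereas the number of correct predictions is divided by the number of samples.

```python
import torch

# Made-up logits for 4 samples and 3 classes
logits = torch.tensor([[2.0, 0.1, -1.0],
                       [0.2, 0.3, 3.0],
                       [1.5, 2.5, 0.0],
                       [0.0, -1.0, 0.5]])
y = torch.tensor([0, 2, 0, 2])

pred_class = logits.argmax(1)                                # tensor([0, 2, 1, 2])
n_correct = (pred_class == y).type(torch.float).sum().item()
accuracy = n_correct / len(y)
print(n_correct, accuracy)                                   # 3.0 0.75
```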

The training process is conducted over several iterations (*epochs*). During each epoch, the model learns parameters to make better predictions. We print the model’s accuracy and loss at each epoch; we’d like to see the accuracy increase and the loss decrease with every epoch.

```python
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")
```

```
Epoch 1
-------------------------------
loss: 2.298368  [    0/60000]
loss: 2.301712  [ 6400/60000]
loss: 2.297616  [12800/60000]
loss: 2.296241  [19200/60000]
loss: 2.288442  [25600/60000]
loss: 2.295777  [32000/60000]
loss: 2.290642  [38400/60000]
loss: 2.286241  [44800/60000]
loss: 2.275732  [51200/60000]
loss: 2.279243  [57600/60000]
Test Error:
 Accuracy: 32.8%, Avg loss: 2.272157

Epoch 2
-------------------------------
loss: 2.270475  [    0/60000]
loss: 2.272914  [ 6400/60000]
loss: 2.258083  [12800/60000]
loss: 2.257253  [19200/60000]
loss: 2.240810  [25600/60000]
loss: 2.226688  [32000/60000]
loss: 2.222336  [38400/60000]
loss: 2.183838  [44800/60000]
loss: 2.166270  [51200/60000]
loss: 2.130146  [57600/60000]
Test Error:
 Accuracy: 48.9%, Avg loss: 2.103720

Epoch 3
-------------------------------
loss: 2.122261  [    0/60000]
loss: 2.075302  [ 6400/60000]
loss: 1.937434  [12800/60000]
loss: 1.874486  [19200/60000]
loss: 1.662032  [25600/60000]
loss: 1.528823  [32000/600
```

### Define Custom Activation Function
To eventually make the CNN spiking, you cannot use any of the standard activation functions. Instead, you need to define your own activation function based on the F-I (frequency-current) curve of the leaky integrate-and-fire neuron. Fortunately, defining a custom activation function for PyTorch is (*almost*) as simple as defining a Python function. The only additional step is to wrap the function in a class that inherits from `nn.Module`, so that it can be used like any other layer (e.g. inside `nn.Sequential`). Also, make sure to use `torch` operations (rather than, e.g., NumPy) in your function definition, so that autograd can compute gradients through it.

As an example, we will implement the sigmoid linear unit (SiLU):

$$
f(x) = \frac{x}{1+e^{-x}} = x \cdot \sigma (x)
$$

where $\sigma (\cdot)$ is the sigmoid activation function.


```python
# activation function
def silu(x):
    '''
    Applies the Sigmoid Linear Unit (SiLU) function element-wise
    '''
    return torch.sigmoid(x) * x


# class wrapper
class SiLU(nn.Module):
    '''
    Shape:
        - Input: (N, *) where * means any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    '''
    def __init__(self):
        super().__init__() # init the base class

    def forward(self, x):
        return silu(x)
```
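A quick way to check a custom activation function is to compare it with a reference implementation and to let autograd differentiate it. PyTorch also ships SiLU as `torch.nn.functional.silu` (and `nn.SiLU`), so the sketch below, which repeats the two definitions to be self-contained, compares against it and against the analytic derivative $\sigma(x)\left(1 + x(1 - \sigma(x))\right)$:

```python
import torch
import torch.nn.functional as F
from torch import nn


def silu(x):
    '''Custom SiLU, as defined above'''
    return torch.sigmoid(x) * x


class SiLU(nn.Module):
    def forward(self, x):
        return silu(x)


x = torch.linspace(-5, 5, 101, requires_grad=True)
y = SiLU()(x)
print(torch.allclose(y, F.silu(x), atol=1e-6))       # True: matches PyTorch's built-in

y.sum().backward()                                    # autograd differentiates the custom function
s = torch.sigmoid(x.detach())
analytic = s * (1 + x.detach() * (1 - s))             # d/dx [x * sigmoid(x)]
print(torch.allclose(x.grad, analytic, atol=1e-6))    # True
```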

The code below utilizes this function in a CNN.

```python
# Define model
class SiLUNet(nn.Module):
    def __init__(self):
        super(SiLUNet, self).__init__()
        self.silu = SiLU()
        self.conv_stack1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(7, 7), padding=3),
            self.silu,
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        self.conv_stack2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(7, 7), padding=1),
            self.silu,
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        self.linear_stack = nn.Sequential(
            nn.Flatten(),
            nn.Linear(800, 512),
            self.silu,
            nn.Linear(512, 10)
        )

    def forward(self, x):
        conv1 = self.conv_stack1(x)
        conv2 = self.conv_stack2(conv1)
        logits = self.linear_stack(conv2)
        return logits

silu_model = SiLUNet().to(device)
print(silu_model)
```

```
SiLUNet(
  (silu): SiLU()
  (conv_stack1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
    (1): SiLU()
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (conv_stack2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(7, 7), stride=(1, 1), padding=(1, 1))
    (1): SiLU()
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (linear_stack): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=800, out_features=512, bias=True)
    (2): SiLU()
    (3): Linear(in_features=512, out_features=10, bias=True)
  )
)
```


Let's train our newly defined `SiLUNet` on the same classification task as before.

```python
silu_optimizer = torch.optim.SGD(silu_model.parameters(), lr=1e-3)

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, silu_model, loss_fn, silu_optimizer)
    test(test_dataloader, silu_model, loss_fn)
print("Done!")
```

```
Epoch 1
-------------------------------
loss: 2.307589  [    0/60000]
loss: 2.304374  [ 6400/60000]
loss: 2.303083  [12800/60000]
loss: 2.302176  [19200/60000]
loss: 2.298456  [25600/60000]
loss: 2.302263  [32000/60000]
loss: 2.302528  [38400/60000]
loss: 2.303756  [44800/60000]
loss: 2.305499  [51200/60000]
loss: 2.299286  [57600/60000]
Test Error:
 Accuracy: 10.8%, Avg loss: 2.300229

Epoch 2
-------------------------------
loss: 2.304034  [    0/60000]
loss: 2.301336  [ 6400/60000]
loss: 2.299125  [12800/60000]
loss: 2.298839  [19200/60000]
loss: 2.295588  [25600/60000]
loss: 2.297917  [32000/60000]
loss: 2.299269  [38400/60000]
loss: 2.299597  [44800/60000]
loss: 2.301291  [51200/60000]
loss: 2.295222  [57600/60000]
Test Error:
 Accuracy: 26.9%, Avg loss: 2.295960

Epoch 3
-------------------------------
loss: 2.299838  [    0/60000]
loss: 2.297535  [ 6400/60000]
loss: 2.294022  [12800/60000]
loss: 2.294523  [19200/60000]
loss: 2.291507  [25600/60000]
loss: 2.291777  [32000/600
```

## Task 1 - (50 points)
Create a custom LIF activation function. Be careful to include all neuron parameters required for the activation function. For numerical stability, it might be helpful to add a small value (e.g. $10^{-20}$) to the $\rho ( \cdot )$ function.

$$
\left[ \begin{array}{lr}
\tau_m & = 0.02\,\text{s} \\
\tau_{ref} & = 0.004\,\text{s} \\
V_{thr} & = 1\,\text{mV}
\end{array} \right]
$$


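```python
import torch.nn.functional as F

# activation function
def lif(x):
    '''
    Applies the soft LIF rate activation function element-wise
    '''
    tau_m = 0.02     # membrane time constant (s)
    tau_ref = 0.004  # refractory period (s)
    V_thr = 1        # firing threshold
    gamma = 0.15     # smoothing parameter of the soft maximum rho

    # rho(z) = gamma * ln(1 + exp(z / gamma)), computed with the numerically stable
    # softplus (torch.exp would overflow for large inputs); the small constant
    # keeps V_thr / rho finite when rho underflows to 0
    rho = F.softplus(x - V_thr, beta=1 / gamma) + 1e-20

    return 1 / (tau_ref + tau_m * torch.log(1 + V_thr / rho))


# class wrapper
class LIF(nn.Module):
    '''
    Shape:
        - Input: (N, *) where * means any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    '''
    def __init__(self):
        super().__init__() # init the base class

    def forward(self, x):
        return lif(x)
```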
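Before training, it can be worth checking the activation function numerically. The sketch below repeats the soft-LIF formula with the parameters above so that it runs on its own. It checks that the rate increases with the input, stays below $1/\tau_{ref} = 250$ Hz, and has finite gradients at the chosen inputs. It also compares with a naive version that computes the soft maximum as `gamma * torch.log(1 + torch.exp(z / gamma))`: in 32-bit floats, `torch.exp` overflows once $(x - V_{thr})/\gamma$ exceeds roughly $88$, which in this sketch yields a rate of $250$ Hz instead of about $199$ Hz at $x = 20$, and a `nan` gradient. `lif_naive` and `lif_stable` are hypothetical names, and `GAMMA = 0.15` follows the solution above rather than the task statement.

```python
import torch
import torch.nn.functional as F

TAU_M, TAU_REF, V_THR, GAMMA = 0.02, 0.004, 1.0, 0.15


def lif_naive(x):
    """Soft LIF with the soft maximum written out via torch.exp (overflows for large inputs)."""
    rho = GAMMA * torch.log(1 + torch.exp((x - V_THR) / GAMMA)) + 1e-20
    return 1 / (TAU_REF + TAU_M * torch.log(1 + V_THR / rho))


def lif_stable(x):
    """Same function, with the soft maximum computed by the numerically stable softplus."""
    rho = F.softplus(x - V_THR, beta=1 / GAMMA) + 1e-20
    return 1 / (TAU_REF + TAU_M * torch.log(1 + V_THR / rho))


x = torch.tensor([-2.0, 0.0, 1.0, 2.0, 20.0], requires_grad=True)
rates = lif_stable(x)
rates.sum().backward()
print(rates.detach())   # increasing; about 199 Hz at x = 20, below 1 / TAU_REF = 250 Hz
print(x.grad)           # finite at all five inputs

x_big = torch.tensor([20.0], requires_grad=True)
rate_big = lif_naive(x_big)
rate_big.backward()
print(rate_big.item(), x_big.grad)  # about 250 and nan: exp overflowed to inf
```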

## Task 2 - (50 points)
Create a convolutional neural network that uses the LIF activation function for the **first** convolutional layer, and train the network to classify the Fashion-MNIST dataset.

Note that the LIF activation function can produce quite large values (with the parameters above, rates of up to $1/\tau_{ref} = 250$ Hz). This is why we only use it in the first convolutional layer; you can choose either ReLU or SiLU for the other layers. With too many LIF neurons, gradient descent becomes unstable. Even a single layer with a LIF activation function makes this network harder to train than the previous ones, so you will likely have to adjust the learning rate.

```python
# Define model
class LIFNet(nn.Module):
    def __init__(self):
        super(LIFNet, self).__init__()
        self.lif = LIF()
        self.conv_stack1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(7, 7), padding=3),
            self.lif,
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        self.conv_stack2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(7, 7), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        self.linear_stack = nn.Sequential(
            nn.Flatten(),
            nn.Linear(800, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        conv1 = self.conv_stack1(x)
        conv2 = self.conv_stack2(conv1)
        logits = self.linear_stack(conv2)
        return logits
```


```python
lif_model = LIFNet().to(device)

# LR = 1e-5 seems to lead to "decent" performance but is still somewhat high
# (it occasionally causes stability issues)
lif_optimizer = torch.optim.SGD(lif_model.parameters(), lr=1e-5)

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, lif_model, loss_fn, lif_optimizer)
    test(test_dataloader, lif_model, loss_fn)
print("Done!")
```

```
Epoch 1
-------------------------------
loss: 2.326438  [    0/60000]
loss: 2.329180  [ 6400/60000]
loss: 2.291061  [12800/60000]
loss: 2.336880  [19200/60000]
loss: 2.336195  [25600/60000]
loss: 2.255900  [32000/60000]
loss: 2.321995  [38400/60000]
loss: 2.285791  [44800/60000]
loss: 2.270254  [51200/60000]
loss: 2.301459  [57600/60000]
Test Error:
 Accuracy: 17.2%, Avg loss: 2.264489

Epoch 2
-------------------------------
loss: 2.252483  [    0/60000]
loss: 2.274608  [ 6400/60000]
loss: 2.248017  [12800/60000]
loss: 2.285228  [19200/60000]
loss: 2.258527  [25600/60000]
loss: 2.229614  [32000/60000]
loss: 2.254202  [38400/60000]
loss: 2.247846  [44800/60000]
loss: 2.235652  [51200/60000]
loss: 2.247137  [57600/60000]
Test Error:
 Accuracy: 31.7%, Avg loss: 2.226393

Epoch 3
-------------------------------
loss: 2.220045  [    0/60000]
loss: 2.245704  [ 6400/60000]
loss: 2.216213  [12800/60000]
loss: 2.257740  [19200/60000]
loss: 2.214200  [25600/60000]
loss: 2.198403  [32000/600
```
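The comment in the training cell notes that `lr=1e-5` still occasionally causes stability issues. One common, generic remedy, which is *not* part of the original solution, is to clip the gradient norm between `loss.backward()` and `optimizer.step()` using `torch.nn.utils.clip_grad_norm_`. The sketch below shows a single such training step on a toy fully connected network with a soft-LIF layer. The model, the random data and `max_norm=1.0` are illustrative assumptions, not tuned values; in the tutorial, the clipping line would go into the `train` function.

```python
import torch
import torch.nn.functional as F
from torch import nn


class LIF(nn.Module):
    """Soft-LIF rate activation with tau_m=0.02, tau_ref=0.004, V_thr=1, gamma=0.15."""
    def forward(self, x):
        rho = F.softplus(x - 1.0, beta=1 / 0.15) + 1e-20
        return 1 / (0.004 + 0.02 * torch.log(1 + 1.0 / rho))


torch.manual_seed(0)
# Toy stand-in for LIFNet: one LIF hidden layer on random 20-dimensional inputs
model = nn.Sequential(nn.Linear(20, 16), LIF(), nn.Linear(16, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
loss_fn = nn.CrossEntropyLoss()
X, y = torch.randn(64, 20), torch.randint(0, 10, (64,))

loss = loss_fn(model(X), y)
optimizer.zero_grad()
loss.backward()
# Rescale all gradients together so that their overall L2 norm is at most max_norm;
# the return value is the norm *before* clipping
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

norm_after = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
print(float(norm_before), float(norm_after))
```

Clipping limits the size of each update without changing its direction, so it can be combined with a somewhat larger learning rate; whether that helps for `LIFNet` would have to be checked empirically.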