# Spiking multicompartment predictive coding (PC) network

## Abstract
Predictive coding is a promising theoretical framework for understanding hierarchical sensory processing in the brain, yet how it is implemented with cortical spiking neurons is still unclear. While most existing work has taken a hand-wiring approach to creating microcircuits that match experimental results, recent work applying an optimisation approach revealed that cortical connectivity might result from self-organisation given some fundamental computational principle, i.e., energy efficiency. We therefore investigated whether predictive coding properties in a multicompartment spiking neural network can result from energy optimisation. We found that only the model trained with an energy objective in addition to a task-relevant objective was able to reconstruct internal representations given top-down expectation signals alone. Neurons in the energy-optimised model also showed differential responses to expected versus unexpected stimuli, qualitatively similar to experimental evidence for predictive coding. These findings indicate that predictive-coding-like behaviour might be an emergent property of energy optimisation, providing a new perspective on how predictive coding could be achieved in the cortex.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.autograd import Variable
import torchvision.transforms as transforms
import torchvision
import numpy as np
from datetime import date
import os
import pandas as pd
import math
import shutil
import matplotlib.pyplot as plt
import seaborn as sns

from predcoding.snn.network import SnnNetwork3Layer
from predcoding.training import train_fptt, get_stats_named_params, reset_named_params
from predcoding.snn.experiments.eval import test
from predcoding.utils import count_parameters, save_checkpoint

# Prefer the GPU whenever CUDA is available; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Fix torch's global RNG so torch-driven randomness (e.g. parameter
# initialisation) is reproducible from run to run.
torch.manual_seed(999)

In [None]:
# ToTensor maps pixels to [0, 1]; Normalize then applies (x - 0.5) / 0.5 -> [-1, 1].
# Single-element tuples: MNIST has one channel.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

batch_size = 200

traindata = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)

testdata = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)

# data loading
# The training loader reshuffles every epoch so optimisation does not see the
# same fixed batch order each time (the order comes from torch's global RNG,
# so it is reproducible under a fixed seed). The test loader keeps a fixed
# order so evaluation results and recorded states line up with testdata indices.
train_loader = torch.utils.data.DataLoader(
    traindata, batch_size=batch_size, shuffle=True, num_workers=2
)
test_loader = torch.utils.data.DataLoader(
    testdata, batch_size=batch_size, shuffle=False, num_workers=2
)

## Defining the network

In [None]:
# ---- network / loss parameters ----
adap_neuron = True  # adaptive neurons on/off (passed to the model as is_adapt)
clf_alpha = 1  # presumably the weight of the classification loss (train_fptt arg)
energy_alpha = 0.05  # presumably the weight of the energy loss (train_fptt arg)
spike_alpha = 0.0  # energy loss on spikes; 0.0 switches that term off
num_readout = 10
onetoone = True  # passed to the model as one_to_one
lr = 1e-3  # initial learning rate of the optimiser
alg = "fptt"  # training-algorithm label
dp = 0.4  # passed to the model as dp_rate
is_rec = False  # passed to the model as is_rec


# ---- training parameters ----
T = 50  # total simulation time steps per sample
K = 10  # k_updates: number of parameter updates per sequence
omega = T // K  # an update happens every omega time steps
clip = 1.0  # passed to train_fptt as the clipping value
log_interval = 20
epochs = 35

In [None]:
# input / layer sizes
IN_dim = 784  # flattened 28x28 MNIST image
hidden_dim = [600, 500, 500]  # sizes of the three hidden layers
n_classes = 10

# define network
model = SnnNetwork3Layer(
    IN_dim,
    hidden_dim,
    n_classes,
    is_adapt=adap_neuron,
    one_to_one=onetoone,
    dp_rate=dp,
    is_rec=is_rec,
)
model.to(device)
print(model)

# report model size
total_params = count_parameters(model)
print("total param count %i" % total_params)

# define optimiser: Adamax with L2 weight decay
optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=0.0001)
# StepLR halves the learning rate every 2 scheduler steps; the training loop
# calls scheduler.step() once per epoch, so lr = 1e-3 * 0.5 ** (epoch // 2).
# NOTE(review): an earlier comment here described "reduce after 20 epochs by a
# factor of 10" (step_size=20, gamma=0.1). Confirm which schedule is intended:
# with these settings lr has fallen to ~7.6e-9 by epoch 35.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

## Train & test

In [None]:
# Baseline: test-set loss and top-1 accuracy of the freshly initialised network.
(test_loss, acc1) = test(model, test_loader, T)

In [None]:
# Train for `epochs` epochs with FPTT. After each epoch: evaluate on the test
# set, advance the LR scheduler, and save a checkpoint whenever test accuracy
# beats the best value seen so far.
named_params = get_stats_named_params(model)
all_test_losses = []  # test loss recorded after every epoch
best_acc1 = 20  # accuracy floor: nothing is checkpointed until acc1 exceeds it

for epoch in range(epochs):
    # One training epoch with FPTT (K updates per sequence, every omega steps).
    # NOTE(review): `lr` is the constant initial learning rate. The StepLR
    # scheduler changes the optimizer's lr, but train_fptt always receives this
    # constant. Confirm how train_fptt uses it.
    train_fptt(
        epoch,
        batch_size,
        log_interval,
        train_loader,
        model,
        named_params,
        T,
        K,
        omega,
        optimizer,
        clf_alpha,
        energy_alpha,
        spike_alpha,
        clip,
        lr,
    )

    # presumably resets the per-parameter statistics gathered by
    # get_stats_named_params before the next epoch; confirm in predcoding.training
    reset_named_params(named_params)

    test_loss, acc1 = test(model, test_loader, T)

    # advance the learning-rate schedule once per epoch
    scheduler.step()

    # remember best acc@1 and save checkpoint
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)

    if is_best:
        save_checkpoint(
            {
                "epoch": epoch + 1,  # number of epochs completed (1-based)
                "state_dict": model.state_dict(),
                # 'oracle_state_dict': oracle.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
                # 'oracle_optimizer' : oracle_optim.state_dict(),
            },
            is_best,
            prefix="voltage_diff_",
            filename="best.pth.tar",
        )

    all_test_losses.append(test_loss)

## Plotting

In [None]:
def get_states(
    hiddens_all_: list, idx: int, hidden_dim_: int, batch_size, T=20, num_samples=10000
):
    """
    Gather one internal state from recorded hidden states into a single array.

    Args:
        hiddens_all_: list with one entry per batch. Each entry is indexed by
            time step and then by state index, i.e.
            ``hiddens_all_[batch][t][idx]`` is a tensor whose first dimension
            indexes samples.
        idx: which entry of each time step's hidden state is taken out.
        hidden_dim_: the size of that state, e.g. number of r or p neurons.
        batch_size: number of samples taken from each batch (the first
            ``batch_size`` rows of each recorded state).
        T: number of time steps collected from each batch.
        num_samples: total number of samples, i.e.
            ``len(hiddens_all_) * batch_size``.

    Returns:
        np.ndarray of shape ``(num_samples, T, hidden_dim_)``, ordered
        batch-major: sample ``i`` of batch ``k`` is row ``k * batch_size + i``.

    Raises:
        IndexError: if a recorded state has fewer than ``batch_size`` samples
            (or if a batch has fewer than ``T`` time steps).
    """
    per_batch = []
    for hiddens in hiddens_all_:
        steps = []
        for t in range(T):
            state = hiddens[t][idx]
            if state.shape[0] < batch_size:
                raise IndexError(
                    f"state {idx} at t={t} has {state.shape[0]} samples, "
                    f"expected at least {batch_size}"
                )
            # One device->host copy per time step instead of one per sample.
            steps.append(state[:batch_size].detach().cpu().numpy())
        per_batch.append(np.stack(steps))  # (T, batch_size, dim)

    all_states = np.stack(per_batch)  # (n_batches, T, batch_size, dim)

    # (n_batches, T, batch_size, dim) -> (n_batches, batch_size, T, dim)
    # -> (num_samples, T, hidden_dim_)
    return all_states.transpose(0, 2, 1, 3).reshape(num_samples, T, hidden_dim_)

In [None]:
# Re-run the test set with the model switched to evaluation mode;
# nn.Module.eval() returns the module itself, so it can be passed straight in.
test(model.eval(), test_loader, T)

In [None]:
# saved_dict = model_result_dict_load('/content/onelayer_rec_best.pth.tar')
# model.load_state_dict(saved_dict['state_dict'])

In [None]:
# Gather the names of every trainable parameter (requires_grad=True).
param_names_wE = []
param_dict_wE = {}  # NOTE(review): never filled in the visible code
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue  # frozen parameter: skip it
    param_names_wE.append(name)

print(param_names_wE)

In [None]:
# clamped generation of internal representations
# For each class label, drive the network with an all-zero input while clamping
# that class (model.clamped_generate). Store each hidden layer's state averaged
# over the clamp_T time steps.
no_input = torch.zeros((1, IN_dim)).to(device)
clamp_T = T * 5  # clamped runs last 5x the normal number of time steps

# one row per class label, one column per neuron of the layer
l1_clamp_E = np.zeros((n_classes, hidden_dim[0]))
l2_clamp_E = np.zeros((n_classes, hidden_dim[1]))
l3_clamp_E = np.zeros((n_classes, hidden_dim[2]))

# NOTE(review): 1, 5 and 9 index into each time step's hidden-state tuple and
# are presumably the layer-1/2/3 states reported as "E". Confirm against
# SnnNetwork3Layer.init_hidden.
for i in range(n_classes):
    print(i)
    with torch.no_grad():
        model.eval()

        hidden_i = model.init_hidden(1)

        _, hidden_gen_E_ = model.clamped_generate(
            i, no_input, hidden_i, clamp_T, clamp_value=1
        )

        # each get_states call returns shape (1, clamp_T, layer_dim)
        l1_E = get_states([hidden_gen_E_], 1, hidden_dim[0], 1, clamp_T, num_samples=1)
        l2_E = get_states([hidden_gen_E_], 5, hidden_dim[1], 1, clamp_T, num_samples=1)
        l3_E = get_states([hidden_gen_E_], 9, hidden_dim[2], 1, clamp_T, num_samples=1)

        # average over time -> (layer_dim,)
        l1_clamp_E[i] = np.squeeze(l1_E.mean(axis=1))
        l2_clamp_E[i] = np.squeeze(l2_E.mean(axis=1))
        l3_clamp_E[i] = np.squeeze(l3_E.mean(axis=1))

    # release cached CUDA memory between classes
    torch.cuda.empty_cache()

In [None]:
# ------------------------------------------------------------
# decode from clamped representations
# ------------------------------------------------------------
# a single all-zero input image, created directly on the compute device
no_input = torch.zeros(1, IN_dim, device=device)

# mean-squared-error criterion, presumably for fitting the linear decoder
MSE_loss = nn.MSELoss()

# an additional, unshuffled loader over the test set
test_loader2 = torch.utils.data.DataLoader(
    dataset=testdata,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)


# %%
def plot_projection(rep, label, weights, bias):
    """Linearly project a representation to image space and display it.

    Computes ``weights @ rep + bias``, reshapes the result to a 28x28 image,
    shows it with ``label`` as the title, and returns that image.
    """
    projected = weights @ rep + bias
    img = projected.reshape(28, 28)

    plt.imshow(img)
    plt.title(str(label))
    plt.show()

    return img


# %%
# Fit a linear decoder from hidden layer `layer` back to the image plane.
# `layer` is 0-based: it indexes hidden_dim, and plots label it as L(layer + 1).
# NOTE(review): train_linear_proj is neither defined nor imported anywhere in
# this notebook, so this cell raises NameError as written. Presumably it lives
# in the predcoding package; import or define it before running. Judging from
# its use below, it returns (decoder module, sequence of training losses).
layer = 1
l2_E_decoder, loss_E = train_linear_proj(layer, model)

# decoders[0] is applied to clamped layer representations in the plotting cell
decoders = [l2_E_decoder]

# %%
# plot loss curve of training
colors = [
    (0.1271049596309112, 0.4401845444059977, 0.7074971164936563),
    (0.9949711649365629, 0.5974778931180315, 0.15949250288350636),
]
sns.set_style("whitegrid", {"axes.grid": False})
# increase font size; set it before the figure exists so every text element
# created for this figure uses it
plt.rcParams.update({"font.size": 14})

fig, ax = plt.subplots(figsize=(5, 4))

ax.plot(loss_E, label="Energy L%i" % (layer + 1), color=colors[0])
# frame off
ax.spines[["right", "top"]].set_visible(False)
ax.set_ylabel("MSE loss")
ax.set_xlabel("steps")
# one legend call (the previous ax.legend() was immediately replaced anyway)
ax.legend(frameon=False)
plt.show()

In [None]:
# plot decoding of clamped internal representations
# Pick the clamped representation of the same layer the decoder was trained on,
# so the decoder input always has hidden_dim[layer] features.
clamp_reps = [l1_clamp_E, l2_clamp_E, l3_clamp_E][layer]

# squeeze=False keeps a 2-D axes array even when n_classes == 1
fig, axes = plt.subplots(1, n_classes, figsize=(n_classes, 2), squeeze=False)
axes = axes[0]

with torch.no_grad():
    for proj_class in range(n_classes):
        rep = (
            torch.tensor(clamp_reps[proj_class].astype("float32"))
            .to(device)
            .view(-1, hidden_dim[layer])
        )
        # decoder output is reshaped into a 28x28 image for display
        img1 = decoders[0](rep).reshape(28, 28).cpu()
        axes[proj_class].imshow(img1, cmap="viridis")
        axes[proj_class].set_title(str(proj_class))
        axes[proj_class].tick_params(
            left=False, right=False, labelleft=False, labelbottom=False, bottom=False
        )

fig.suptitle("projection from clamped rep back to image plane layer %i" % (layer + 1))
axes[0].set_ylabel("Energy", rotation=0, labelpad=40)

plt.tight_layout()
plt.show()