In [0]:
""" imports """
# Hamiltonian Predictive Coding. Merges MCMC sampling with predictive coding networks to give, in theory, a biologically plausible way for the brain to perform MCMC sampling
# with any arbitrary function. The HMC steps are not actually biologically implausible at all, and can explain both neural oscillations and inhibitory cell populations beyond those required for error units.
import itertools

import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision

# Functions

In [0]:
""" functions """

def set_tensor(tensor):
    """Move ``tensor`` to the module-level ``DEVICE`` and cast it to float32."""
    on_device = tensor.to(DEVICE)
    return on_device.float()


def tanh(xs):
    """Apply the hyperbolic tangent element-wise to the tensor ``xs``."""
    activated = torch.tanh(xs)
    return activated


def linear(x):
    """Identity activation: hand back ``x`` itself, untouched."""
    unchanged = x
    return unchanged


def tanh_deriv(xs):
    """Element-wise derivative of tanh evaluated at ``xs``: ``1 - tanh(xs)^2``."""
    squashed = torch.tanh(xs)
    return 1.0 - squashed ** 2.0


def linear_deriv(x):
    """Derivative of the identity activation.

    The argument is ignored: a one-element tensor holding 1.0 is returned
    (after passing through ``set_tensor``), which broadcasts against
    whatever it is multiplied with.
    """
    unit = torch.ones((1,))
    return set_tensor(unit)


def onehot(x):
    """Return a length-10 NumPy vector that is 1.0 at position ``x`` and 0.0 elsewhere."""
    encoded = np.zeros([10])
    encoded[x] = 1.0
    return encoded


def get_batch_size(x_batch=None, y_batch=None):
    """Return the size of dimension 1 of the first batch that is not ``None``.

    ``x_batch`` takes precedence over ``y_batch``; ``None`` is returned when
    neither is supplied.
    """
    for candidate in (x_batch, y_batch):
        if candidate is not None:
            return candidate.size()[1]
    return None


def flatten_array(array):
    """Concatenate a sequence of tensors along dimension 0 and flatten the result to 1-D."""
    joined = torch.cat(array)
    return joined.flatten()


def classification_accuracy(pred_labels, true_labels):
    """Return the fraction of columns whose arg-max class matches the target.

    Both arguments are indexed as ``[class, sample]``: the predicted class of
    sample ``b`` is the arg-max over ``pred_labels[:, b]`` and likewise for
    ``true_labels``.

    Returns:
        A Python float in ``[0, 1]``.

    Raises:
        ZeroDivisionError: if ``pred_labels`` has no columns.
    """
    batch_size = pred_labels.size()[1]
    # One vectorised comparison instead of a Python loop with a
    # tensor-to-bool conversion (and device sync) per sample.
    hits = torch.argmax(pred_labels, dim=0) == torch.argmax(true_labels, dim=0)
    return int(hits.sum().item()) / batch_size


def get_img_list(dataset, n_batches, batch_size):
    """Build ``n_batches`` image batches from the start of ``dataset``.

    Item ``k`` of ``dataset`` is read as ``dataset[k][0]``, converted to a
    NumPy array, reshaped to 784 values and divided by 255. Each returned
    batch is a tensor of shape (784, batch_size) with one image per column,
    passed through ``set_tensor``.
    """
    batches = []
    for batch_idx in range(n_batches):
        offset = batch_idx * batch_size
        columns = []
        for i in range(batch_size):
            pixels = np.array(dataset[offset + i][0])
            columns.append(pixels.reshape([784, 1]) / 255.0)
        # (batch, 784, 1) -> transpose -> (1, 784, batch) -> (784, batch)
        stacked = np.array(columns).T.reshape([784, batch_size])
        batches.append(set_tensor(torch.from_numpy(stacked)))
    return batches


def get_label_list(dataset, n_batches, batch_size):
    """Build ``n_batches`` one-hot label batches from the start of ``dataset``.

    Item ``k`` of ``dataset`` contributes ``onehot(dataset[k][1])`` as one
    column, so each returned tensor has shape (10, batch_size). Tensors are
    passed through ``set_tensor``.
    """
    batches = []
    for batch_idx in range(n_batches):
        offset = batch_idx * batch_size
        rows = []
        for i in range(batch_size):
            rows.append(onehot(dataset[offset + i][1]))
        # rows is (batch, 10); transpose so that samples run along columns.
        label_matrix = np.array(rows).T
        batches.append(set_tensor(torch.from_numpy(label_matrix)))
    return batches


def plot_images(imgs):
    """Show the columns of ``imgs`` as 28x28 grayscale images on a 2x5 grid.

    ``imgs`` is indexed as ``[pixel, image]``. The grid has ten axes, so more
    than ten columns raises ``IndexError``.
    """
    tiles = []
    for col in range(imgs.shape[1]):
        tiles.append(np.reshape(imgs[:, col], [28, 28]))
    _, grid = plt.subplots(2, 5)
    panels = grid.flatten()
    for slot, tile in enumerate(tiles):
        panels[slot].imshow(tile, cmap="gray")
    plt.show()

def symplectic_leapfrog_integrator(mus, momenta, dFdmu, path_len, step_size):
    """Run a leapfrog-style trajectory and return ``(new_mus, -new_momenta)``.

    The schedule is: half momentum step, then ``int(path_len)`` rounds of
    (full position step, full momentum step), then one more full position
    step and a closing half momentum step. The returned momenta are negated.

    ``mus`` and ``momenta`` are cloned first, so the caller's tensors are not
    modified.

    NOTE(review): the same ``dFdmu`` tensor is reused at every step; it is
    not re-evaluated at the updated positions.
    """
    position = mus.clone()
    velocity = momenta.clone()

    # Both increments are loop-invariant because dFdmu never changes.
    full_kick = step_size * dFdmu
    half_kick = full_kick / 2

    velocity -= half_kick  # opening half step
    for _ in range(int(path_len)):
        # Alternate position and momentum updates.
        position -= step_size * velocity
        velocity -= full_kick
    position -= step_size * velocity
    velocity -= half_kick  # closing half step

    # Momentum flip at the end.
    return position, -velocity

def MH_acceptance(F_init, F_end):
    """Metropolis-Hastings accept/reject decision for a batch of proposals.

    Each element is accepted with probability ``min(1, exp(F_init - F_end))``,
    so a proposal that lowers the free energy is always accepted.

    Args:
        F_init: tensor of free energies at the current samples.
        F_end: tensor of free energies at the proposed samples (same shape).

    Returns:
        Boolean tensor of the same shape; ``True`` means accept the proposal,
        ``False`` means keep the current sample.
    """
    accept_prob = torch.clamp(torch.exp(F_init - F_end), max=1.0)
    # Draw on the same device/dtype as the inputs rather than creating a CPU
    # tensor and copying it through the global DEVICE.
    uniform = torch.rand_like(accept_prob)
    # rand is in [0, 1): strict '<' accepts with probability exactly
    # accept_prob, so a probability of 0 can never be accepted.
    return uniform < accept_prob

# Standard Predictive Coding Layer

In [0]:
""" layers """

class PredictiveCodingLayer(object):
    """One generative layer of a predictive coding network.

    The layer owns a state ``mu`` of shape (output_size, batch) and a weight
    matrix of shape (input_size, output_size); its prediction is
    ``fn(weights @ mu)`` with shape (input_size, batch).
    """

    def __init__(self, input_size, output_size, learning_rate, fn, fn_deriv):
        initial = torch.empty((input_size, output_size)).normal_(mean=0.,std=0.1)
        self.weights = set_tensor(initial)
        self.mu = None  # allocated by reset_mu
        self.input_size = input_size
        self.output_size = output_size
        self.learning_rate = learning_rate
        self.fn = fn
        self.fn_deriv = fn_deriv

    def reset_mu(self, batch_size):
        """Re-draw ``mu`` from a standard normal with shape (output_size, batch_size)."""
        fresh = torch.empty((self.output_size, batch_size)).normal_(mean=0.,std=1.)
        self.mu = set_tensor(fresh)

    def predict(self):
        """Return ``fn(weights @ mu)``."""
        pre_activation = self.weights @ self.mu
        return self.fn(pre_activation)

    def update_mu(self, prior_err, likelihood_err):
        """Update ``mu`` in place by ``learning_rate * (W^T (likelihood_err * fn') - prior_err)``."""
        slope = self.fn_deriv(self.weights @ self.mu)
        back_projected = self.weights.T @ (likelihood_err * slope)
        self.mu += self.learning_rate * (back_projected - prior_err)

    def update_weights(self, pred_err):
        """Update ``weights`` in place by ``learning_rate * (pred_err * fn') @ mu^T``."""
        slope = self.fn_deriv(self.weights @ self.mu)
        self.weights += self.learning_rate * ((pred_err * slope) @ self.mu.T)

class AmortisedLayer(object):
    """Single feed-forward layer of the amortised network.

    ``weights`` is created with shape (input_size, output_size) and applied
    to an explicit ``state`` as ``fn(weights @ state)``; the layer holds no
    state of its own.
    """

    def __init__(self, input_size, output_size, learning_rate, fn, fn_deriv):
        initial = torch.empty((input_size, output_size)).normal_(mean=0.,std=0.1)
        self.weights = set_tensor(initial)
        self.input_size = input_size
        self.output_size = output_size
        self.learning_rate = learning_rate
        self.fn = fn
        self.fn_deriv = fn_deriv

    def predict(self, state):
        """Return ``fn(weights @ state)``."""
        pre_activation = self.weights @ state
        return self.fn(pre_activation)

    def update_weights(self, pred_err, state):
        """Update ``weights`` in place by ``learning_rate * (pred_err * fn') @ state^T``."""
        slope = self.fn_deriv(self.weights @ state)
        self.weights += self.learning_rate * ((pred_err * slope) @ state.T)

# Hamiltonian Layer

In [0]:
""" layers """

class HamiltonianPredictiveCodingLayer(object):
    """Predictive coding layer whose state can also be updated by HMC-style sampling.

    The layer owns a state ``mu`` of shape (output_size, batch) and weights of
    shape (input_size, output_size); it predicts the layer below as
    ``fn(weights @ mu)``. Besides the deterministic ``update_mu`` step it
    offers ``PC_HMC_sample_step`` / ``HMC_sampling``, which propose new states
    with a leapfrog integrator and filter them with a Metropolis-Hastings test
    on the layer-wise free energy.
    """

    def __init__(self, input_size, output_size, learning_rate, fn, fn_deriv,N_HMC_samples, integrator_step_size, integrator_path_len,mass_metric_scalar):
        self.input_size = input_size
        self.output_size = output_size
        self.learning_rate = learning_rate
        self.fn = fn
        self.fn_deriv = fn_deriv
        weights = torch.empty((input_size, output_size)).normal_(mean=0.,std=0.1)
        self.weights = set_tensor(weights)
        # Sampler configuration.
        self.N_HMC_samples = N_HMC_samples
        self.integrator_step_size = integrator_step_size
        self.integrator_path_len = integrator_path_len
        self.mass_metric_scalar = mass_metric_scalar
        self.mu = None  # allocated by reset_mu

    def reset_mu(self, batch_size):
        """Re-draw ``mu`` from a standard normal with shape (output_size, batch_size)."""
        mu = torch.empty((self.output_size, batch_size)).normal_(mean=0.,std=1.)
        self.mu = set_tensor(mu)

    def predict(self):
        """Return the prediction from the layer's own state: ``fn(weights @ mu)``."""
        return self.predict_with_mu(self.mu)

    def update_mu(self, prior_err, likelihood_err):
        """Update ``mu`` in place by ``learning_rate * (W^T (likelihood_err * fn') - prior_err)``."""
        fn_deriv = self.predict_deriv(self.mu)
        delta = torch.matmul(self.weights.T, likelihood_err * fn_deriv)
        delta = -prior_err + delta
        self.mu += self.learning_rate * delta

    def update_weights(self, pred_err):
        """Update ``weights`` in place by ``learning_rate * (pred_err * fn') @ mu^T``."""
        fn_deriv = self.predict_deriv(self.mu)
        delta = torch.matmul(pred_err * fn_deriv, self.mu.T)
        self.weights += self.learning_rate * delta

    def predict_with_mu(self, mu):
        """Return ``fn(weights @ mu)`` for an explicitly supplied state."""
        return self.fn(torch.matmul(self.weights, mu))

    def predict_deriv(self, mu):
        """Return ``fn_deriv(weights @ mu)`` for an explicitly supplied state."""
        return self.fn_deriv(torch.matmul(self.weights, mu))

    def PC_HMC_sample_step(self, prev_sample, pred_above, mu_below):
        """Propose one new sample per batch column and apply an MH accept/reject.

        Args:
            prev_sample: current state, shape (output_size, batch).
            pred_above: prediction of this state from the layer above
                (treated as fixed during the step); same shape as
                ``prev_sample``.
            mu_below: activity this layer has to predict,
                shape (input_size, batch).

        Returns:
            ``(sample, momenta)``: ``sample`` equals the proposal in accepted
            columns and ``prev_sample`` in rejected ones; ``momenta`` are the
            (negated) final momenta returned by the integrator for all
            columns.

        NOTE(review): the acceptance test uses only the sums of squared errors
        (no kinetic term, no 1/2 factor) - confirm this is intended.
        """
        # Same update direction that update_mu follows.
        prior_err = prev_sample - pred_above
        likelihood_err = mu_below - self.predict_with_mu(prev_sample)
        fn_deriv = self.predict_deriv(prev_sample)
        likelihood_back = torch.matmul(self.weights.T, likelihood_err * fn_deriv)
        dFdmu = -prior_err + likelihood_back

        # Fresh Gaussian momenta for every step.
        momenta = set_tensor(torch.empty(prev_sample.shape).normal_(mean=0.0,std=1) * self.mass_metric_scalar)
        new_samp, new_momenta = symplectic_leapfrog_integrator(prev_sample,momenta,dFdmu,self.integrator_path_len,self.integrator_step_size)

        # Layer-wise free energies, one value per batch column.
        F_old = torch.sum(prior_err * prior_err,dim=0) + torch.sum(likelihood_err*likelihood_err,dim=0)
        F_new = torch.sum((new_samp - pred_above)**2,dim=0) + torch.sum((mu_below - self.predict_with_mu(new_samp))**2,dim=0)
        MH_bools = MH_acceptance(F_old,F_new)

        # Columns whose proposal was rejected keep their previous value.
        rejected_idxs = torch.where(~MH_bools)[0]
        new_samp[:,rejected_idxs] = prev_sample[:,rejected_idxs]
        return new_samp, new_momenta

    def HMC_sampling(self, mu_below, pred_above):
        """Draw a chain of ``N_HMC_samples`` states starting from the current ``mu``.

        Args:
            mu_below: activity this layer has to predict.
            pred_above: prediction of this layer's state from the layer above.

        Returns:
            Tensor of shape (N_HMC_samples, *mu.shape); entry 0 is a copy of
            ``mu`` and entry ``n`` is the result of one sample step applied to
            entry ``n - 1``.
        """
        # Allocate next to mu so chained steps stay on one device and dtype.
        samples = torch.zeros(
            (self.N_HMC_samples, *self.mu.shape),
            dtype=self.mu.dtype,
            device=self.mu.device,
        )
        samples[0] = self.mu.clone()
        for n in range(1, self.N_HMC_samples):
            # Keywords make the order explicit: the sample step takes
            # (prev_sample, pred_above, mu_below), i.e. the reverse of this
            # method's own argument order.
            new_samp, _ = self.PC_HMC_sample_step(
                samples[n - 1], pred_above=pred_above, mu_below=mu_below
            )
            samples[n] = new_samp
        return samples

class AmortisedLayer(object):
    """Single feed-forward layer of the amortised network.

    ``weights`` is created with shape (input_size, output_size) and applied
    to an explicit ``state`` as ``fn(weights @ state)``.

    NOTE(review): a class with this name and the same body is already defined
    earlier in this file; this definition replaces it when the cell runs.
    """

    def __init__(self, input_size, output_size, learning_rate, fn, fn_deriv):
        initial = torch.empty((input_size, output_size)).normal_(mean=0.,std=0.1)
        self.weights = set_tensor(initial)
        self.input_size = input_size
        self.output_size = output_size
        self.learning_rate = learning_rate
        self.fn = fn
        self.fn_deriv = fn_deriv

    def predict(self, state):
        """Return ``fn(weights @ state)``."""
        pre_activation = self.weights @ state
        return self.fn(pre_activation)

    def update_weights(self, pred_err, state):
        """Update ``weights`` in place by ``learning_rate * (pred_err * fn') @ state^T``."""
        slope = self.fn_deriv(self.weights @ state)
        self.weights += self.learning_rate * ((pred_err * slope) @ state.T)

# Network

In [0]:
class HamiltonianNetwork(object):
    """Hybrid predictive coding network with optional HMC-style state sampling.

    Two stacks of layers are built from ``layer_sizes``:

    * ``v_layers`` (``HamiltonianPredictiveCodingLayer``): layer ``i`` owns a
      state ``mu`` of size ``layer_sizes[i + 1]`` and predicts the activity of
      size ``layer_sizes[i]``.
    * ``q_layers`` (``AmortisedLayer``): layer ``i`` maps the activity at
      level ``i`` to a prediction for level ``i + 1``.

    Batches are indexed as ``[feature, sample]`` throughout.
    """

    def __init__(
        self,
        layer_sizes,
        v_learning_rate,
        q_learning_rate,
        fn,
        fn_deriv,
        threshold,
        N_HMC_samples,
        integrator_step_size,
        integrator_path_len,
        mass_metric_scalar,
        n_infer_steps_train=None,
        n_infer_steps_test=None,
    ):
        """Store the configuration and build both layer stacks.

        Args:
            layer_sizes: sizes of the activity levels, bottom (data) first.
            v_learning_rate: learning rate handed to every variational layer.
            q_learning_rate: learning rate handed to every amortised layer.
            fn, fn_deriv: activation function and its derivative.
            threshold: mean absolute prediction error below which
                ``variational_updates`` stops early.
            N_HMC_samples, integrator_step_size, integrator_path_len,
                mass_metric_scalar: sampler settings forwarded to the
                variational layers.
            n_infer_steps_train, n_infer_steps_test: optional step counts
                stored on the instance. When omitted, module-level variables
                of the same names are used if they exist (the previous
                behaviour), otherwise ``None`` is stored.
        """
        # These used to be read as bare names, which raised NameError unless
        # matching module-level variables happened to be defined first.
        if n_infer_steps_train is None:
            n_infer_steps_train = globals().get("n_infer_steps_train")
        if n_infer_steps_test is None:
            n_infer_steps_test = globals().get("n_infer_steps_test")

        self.layer_sizes = layer_sizes
        self.v_learning_rate = v_learning_rate
        self.q_learning_rate = q_learning_rate
        self.n_infer_steps_train = n_infer_steps_train
        self.n_infer_steps_test = n_infer_steps_test
        self.fn = fn
        self.fn_deriv = fn_deriv
        self.n_activations = len(layer_sizes)
        self.n_layers = len(layer_sizes) - 1
        self.threshold = threshold
        # A trailing comma previously stored this as a 1-tuple.
        self.N_HMC_samples = N_HMC_samples
        self.integrator_step_size = integrator_step_size
        self.integrator_path_len = integrator_path_len
        self.mass_metric_scalar = mass_metric_scalar

        self.build()

    def build(self):
        """Create one variational and one amortised layer per adjacent pair of sizes."""
        self.v_layers = []
        self.q_layers = []
        for i in range(self.n_layers):
            self.v_layers.append(
                HamiltonianPredictiveCodingLayer(
                    input_size=self.layer_sizes[i],
                    output_size=self.layer_sizes[i + 1],
                    learning_rate=self.v_learning_rate,
                    fn=self.fn,
                    fn_deriv=self.fn_deriv,
                    N_HMC_samples=self.N_HMC_samples,
                    integrator_step_size=self.integrator_step_size,
                    integrator_path_len=self.integrator_path_len,
                    mass_metric_scalar=self.mass_metric_scalar,
                )
            )
            self.q_layers.append(
                AmortisedLayer(
                    input_size=self.layer_sizes[i + 1],
                    output_size=self.layer_sizes[i],
                    learning_rate=self.q_learning_rate,
                    fn=self.fn,
                    fn_deriv=self.fn_deriv,
                )
            )

    def reset(self, batch_size):
        """Clear all stored predictions/errors and re-draw every variational state."""
        self.v_preds = [[] for _ in range(self.n_activations)]
        self.v_pred_errs = [[] for _ in range(self.n_activations)]
        self.q_preds = [[] for _ in range(self.n_activations)]
        self.q_pred_errs = [[] for _ in range(self.n_activations)]

        for layer in self.v_layers:
            layer.reset_mu(batch_size)

    def set_input(self, x_batch=None, y_batch=None):
        """Copy ``x_batch`` into ``q_preds[0]`` and ``y_batch`` into the top state, when given."""
        if x_batch is not None:
            self.q_preds[0] = x_batch.clone()
        if y_batch is not None:
            self.v_layers[-1].mu = y_batch.clone()

    def infer(self, n_steps, x_batch=None, y_batch=None,use_sampling=False):
        """
        Run inference with both amortised and variational networks.

        The amortised forward pass initialises every variational state, then
        up to ``n_steps`` variational updates are applied.

        NOTE(review): ``amortised_forward(set_mu=True)`` assigns ``mu`` for
        every variational layer, including the top one that ``set_input``
        just set from ``y_batch``.

        Returns:
            ``{'n_iterations': <number of update iterations performed>}``
        """
        batch_size = get_batch_size(x_batch=x_batch, y_batch=y_batch)
        self.reset(batch_size)
        self.set_input(x_batch, y_batch)
        self.amortised_forward(set_mu=True)
        self.variational_backward()
        n_iterations = self.variational_updates(n_steps, x_batch,use_sampling=use_sampling)
        return {'n_iterations': n_iterations}

    def variational_infer(self, n_steps, x_batch=None, y_batch=None):
        """
        Run inference with the variational network only (randomly initialised states).

        Returns:
            The number of update iterations performed (a plain int).
        """
        batch_size = get_batch_size(x_batch=x_batch, y_batch=y_batch)
        self.reset(batch_size)
        self.set_input(x_batch, y_batch)
        self.variational_backward()
        n_iterations = self.variational_updates(n_steps, x_batch)
        return n_iterations

    def amortised_infer(self, x_batch):
        """
        Run inference with the amortised network only; results land in ``q_preds``.
        """
        batch_size = get_batch_size(x_batch=x_batch)
        self.reset(batch_size)
        self.set_input(x_batch)
        self.amortised_forward(set_mu=False)

    def amortised_forward(self, set_mu=False):
        """
        Forward pass of the amortised network.

        Fills ``q_preds[1:]`` from ``q_preds[0]``; with ``set_mu=True`` each
        prediction is also copied into the matching variational state.
        """
        for i in range(self.n_layers):
            self.q_preds[i + 1] = self.q_layers[i].predict(self.q_preds[i])
            if set_mu:
                self.v_layers[i].mu = self.q_preds[i + 1].clone()

    def variational_backward(self):
        """
        Backward pass of the variational network: fill ``v_preds[i]`` from each layer's state.
        """
        for i in reversed(range(self.n_layers)):
            self.v_preds[i] = self.v_layers[i].predict()

    def variational_updates(self, n_steps, x_batch=None,use_sampling=False):
        """
        Iteratively update the variational states.

        Each iteration recomputes the prediction errors and then updates every
        layer's ``mu``, either with the deterministic ``update_mu`` step or,
        when ``use_sampling`` is true, with one ``PC_HMC_sample_step``. The
        loop stops early once the mean absolute prediction error falls below
        ``self.threshold``.

        Returns:
            The number of iterations performed.

        Raises:
            ValueError: if ``use_sampling`` is true but ``x_batch`` is ``None``
                (the bottom layer has nothing to sample against).
        """
        n_iterations = 0
        for _ in range(n_steps):
            # Without data the bottom-level error is defined as zero.
            if x_batch is None:
                data_err = set_tensor(torch.zeros_like(self.v_preds[0]))
            else:
                data_err = x_batch - self.v_preds[0]
            self.v_pred_errs[0] = data_err
            # Nothing predicts the top state, so its prediction is zero.
            self.v_preds[-1] = set_tensor(torch.zeros_like(self.v_layers[-1].mu))

            for i in range(self.n_layers):
                layer = self.v_layers[i]
                self.v_pred_errs[i + 1] = layer.mu - self.v_preds[i + 1]
                if use_sampling:
                    mu_below = x_batch if i == 0 else self.v_layers[i - 1].mu
                    if mu_below is None:
                        raise ValueError(
                            "use_sampling=True requires x_batch: the bottom "
                            "layer samples against the observed data"
                        )
                    new_mu, _ = layer.PC_HMC_sample_step(layer.mu, self.v_preds[i + 1], mu_below)
                    layer.mu = new_mu.clone()
                else:
                    likelihood_err = self.v_pred_errs[i]
                    prior_err = self.v_pred_errs[i + 1]
                    if i == self.n_layers - 1:
                        # The top layer gets no prior term in the deterministic update.
                        prior_err = set_tensor(torch.zeros_like(layer.mu))
                    layer.update_mu(prior_err, likelihood_err)
                self.v_preds[i] = layer.predict()

            n_iterations += 1
            flat_array = flatten_array(self.v_pred_errs)
            if torch.abs(flat_array).mean() < self.threshold:
                break

        return n_iterations

    def variational_learn(self):
        """
        Update weights of the variational network from the stored prediction errors.
        """
        for i in range(self.n_layers):
            self.v_layers[i].update_weights(self.v_pred_errs[i])

    def amortised_learn(self):
        """
        Update weights of the amortised network.

        Layer ``i`` is trained to map the state below (the data for ``i == 0``)
        onto the variational state ``v_layers[i].mu``.
        """
        for i in range(self.n_layers):
            self.q_pred_errs[i] = self.v_layers[i].mu - self.q_preds[i + 1]
            state = self.q_preds[0] if i == 0 else self.v_layers[i - 1].mu
            self.q_layers[i].update_weights(self.q_pred_errs[i], state)

In [0]:
""" network functions """

def test_accuracy(model, x_list, y_list, n_infer_steps,use_sampling=False):
    """
    Mean classification accuracy of the full model over the given batches.

    For each batch the model runs ``infer`` on the inputs only (no labels) and
    the top variational state is scored against the labels.
    """
    total = 0
    for inputs, targets in zip(x_list, y_list):
        model.infer(n_infer_steps, x_batch=inputs, use_sampling=use_sampling)
        top_state = model.v_layers[-1].mu
        total += classification_accuracy(top_state, targets)
    n_batches = len(x_list)
    return total / n_batches

def test_variational_accuracy(model, x_list, y_list, n_infer_steps):
    """
    Mean classification accuracy of the variational network alone.

    Each batch is inferred with ``model.variational_infer`` for at most
    ``n_infer_steps`` iterations, and the top variational state is scored
    against the labels.
    """
    accuracy = 0
    for (x_batch, y_batch) in zip(x_list, y_list):
        # Use the caller's step budget; this previously read the global
        # n_infer_steps_test and ignored the argument.
        model.variational_infer(n_infer_steps, x_batch=x_batch)
        pred_y = model.v_layers[-1].mu
        accuracy += classification_accuracy(pred_y, y_batch)
    return accuracy / len(x_list)


def test_amortised_accuracy(model, x_list, y_list):
    """
    Mean classification accuracy of the amortised network alone.

    Each batch goes through ``model.amortised_infer`` and the last amortised
    prediction (``model.q_preds[-1]``) is scored against the labels.
    """
    total = 0
    for inputs, targets in zip(x_list, y_list):
        model.amortised_infer(inputs)
        total += classification_accuracy(model.q_preds[-1], targets)
    n_batches = len(x_list)
    return total / n_batches

def train_batch(model, x_batch, y_batch, n_infer_steps,use_sampling=False):
    """
    Train the network on a single batch and return the dict produced by ``model.infer``.

    After inference the top state and the top prediction are both set to the
    labels, then the variational weights are updated.
    """
    info = model.infer(
        n_infer_steps,
        x_batch=x_batch,
        y_batch=y_batch,
        use_sampling=use_sampling,
    )
    top_layer = model.v_layers[-1]
    top_layer.mu = y_batch.clone()
    model.v_preds[-1] = y_batch.clone()
    model.variational_learn()
    # Amortised weights are not updated here: the amortised_learn() call is disabled.
    #model.amortised_learn()
    return info

In [45]:
""" run experiment """

# Device used by set_tensor() for every tensor created in this notebook.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Experiment configuration ---
n_epochs = 100
batch_size = 128
# Reduced from 400 training batches to 5.
#n_batches = 400
n_batches = 5
n_test_batches = 10

# Upper bounds on inference iterations; inference may stop earlier once the
# mean absolute prediction error drops below `threshold`.
n_infer_steps_train = 100
n_infer_steps_test = 1000

v_learning_rate = 0.005
# NOTE(review): name is misspelled ("learing"); it is used consistently below.
q_learing_rate = 0.001
layer_sizes = [784, 300, 100, 10]

#threshold = 0.15
threshold = 0.01
test_every = 1

# --- Data ---
# NOTE(review): the test split is stored under the same "MNIST_train" root
# directory as the training split.
train_set = torchvision.datasets.MNIST("MNIST_train", download=True, train=True)
test_set = torchvision.datasets.MNIST("MNIST_train", download=True, train=False)

# Only the first n_batches * batch_size (resp. n_test_batches * batch_size)
# items of each split are used.
img_list = get_img_list(train_set, n_batches, batch_size)
label_list = get_label_list(train_set, n_batches, batch_size)

test_img_list = get_img_list(test_set, n_test_batches, batch_size)
test_label_list = get_label_list(test_set, n_test_batches, batch_size)

# --- Model ---
model = HamiltonianNetwork(layer_sizes=layer_sizes,
                v_learning_rate=v_learning_rate,
                q_learning_rate=q_learing_rate,
                fn=tanh,
                fn_deriv=tanh_deriv,
                threshold=threshold,
                N_HMC_samples=1000,
                integrator_step_size=0.001,
                integrator_path_len=100,
                mass_metric_scalar=1)

# --- Training / evaluation loop ---
# All updates are written by hand, so autograd is switched off.
with torch.no_grad():
    n_iterations = []  # inference iterations used by each training batch
    accuracy = []  # full-model test accuracy per evaluated epoch
    q_accuracy = []  # amortised-only test accuracy per evaluated epoch
    for epoch in range(n_epochs):
        print(f"> Training Epoch {epoch}")
        for (x_batch, y_batch) in zip(img_list, label_list):
            # use_sampling is left at its default (False) during training,
            # while evaluation below runs with use_sampling=True.
            info = train_batch(model, x_batch, y_batch, n_infer_steps_train)
            n_iterations.append(info['n_iterations'])

        if epoch % test_every == 0:
            ep_acc = test_accuracy(model, test_img_list, test_label_list, n_infer_steps_test,use_sampling=True)
            accuracy.append(ep_acc)
            q_ep_acc = test_amortised_accuracy(model, test_img_list, test_label_list)
            q_accuracy.append(q_ep_acc)
            print(f"Accuracy: {ep_acc}")
            print(f"Amortised Accuracy: {q_ep_acc}")
            # Iteration count of the last training batch of this epoch.
            print(f"Iterations: {info['n_iterations']}")
            # Metric files are overwritten with the full history each time.
            np.save("accuracy.npy" , np.array(accuracy))
            np.save("q_accuracy.npy" , np.array(q_accuracy))
            np.save("n_iterations.npy" , np.array(n_iterations))

> Training Epoch 0
Accuracy: 0.1859375
Amortised Accuracy: 0.11796875
Iterations: 100
> Training Epoch 1
Accuracy: 0.29296875
Amortised Accuracy: 0.11796875
Iterations: 100
> Training Epoch 2
Accuracy: 0.3515625
Amortised Accuracy: 0.11796875
Iterations: 100
> Training Epoch 3
Accuracy: 0.38671875
Amortised Accuracy: 0.11796875
Iterations: 100
> Training Epoch 4
Accuracy: 0.4125
Amortised Accuracy: 0.11796875
Iterations: 100
> Training Epoch 5
Accuracy: 0.3984375
Amortised Accuracy: 0.11796875
Iterations: 100
> Training Epoch 6
Accuracy: 0.42421875
Amortised Accuracy: 0.11796875
Iterations: 100
> Training Epoch 7
Accuracy: 0.43203125
Amortised Accuracy: 0.11796875
Iterations: 100
> Training Epoch 8
Accuracy: 0.40859375
Amortised Accuracy: 0.11796875
Iterations: 100
> Training Epoch 9
Accuracy: 0.43515625
Amortised Accuracy: 0.11796875
Iterations: 100
> Training Epoch 10
Accuracy: 0.43203125
Amortised Accuracy: 0.11796875
Iterations: 100
> Training Epoch 11
Accuracy: 0.4484375
Amortised

KeyboardInterrupt: ignored