# Implementation of zero-divergence Inference Learning in a Predictive Coding Network

## Predictive Coding Network

A predictive coding network is a hierarchical probabilistic model in which the variables on each level are predicted from the variables on the level above.

Variables on adjacent levels are assumed to be related by

$$ P(x_i^l | \bar x^{l+1}) = \mathcal{N}( x_i^l; \mu_i^l, \Sigma_i^l) $$

where

$$ \mu_i^l = {\theta_i^l}^T f(\bar x^{l+1}) $$

with the objective being to maximize

$$ F = \ln P(\bar x^0,...,\bar x^{L-1} | \bar x^L) $$

due to the assumed relationship between adjacent layers this simplifies to

$$ \begin{align*}
    F &= \sum_{l=0}^{L-1} \ln P(\bar x^l | \bar x^{l+1})    \\
    &= \sum_{l=0}^{L-1} \sum_{i=1}^n \ln \mathcal{N}( x_i^l; \mu_i^l, \Sigma_i^l)    \\
    &= \sum_{l=0}^{L-1} \sum_{i=1}^n \left[ \ln \frac{1}{\sqrt{2\pi\Sigma_i^l}} - \frac{1}{2}\frac{(x_i^l - \mu_i^l)^2}{\Sigma_i^l} \right]
\end{align*} $$

ignoring the constant term (since we are going to use the derivative with respect to $x_i^l$)

$$ F = -\frac{1}{2} \sum_{l=0}^{L-1} \sum_{i=1}^n \frac{(x_i^l - \mu_i^l)^2}{\Sigma_i^l} $$

In this model we assume the variances to be 1; letting $\epsilon_i^l = x_i^l - \mu_i^l$ gives

$$ F = -\frac{1}{2} \sum_{l=0}^{L-1} \sum_{i=1}^n (\epsilon_i^l)^2 $$

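To update each $x_i^l$ we use the partial derivative of $F$ with respect to $x_i^l$ (writing $\theta_{k,i}^{l-1}$ for the $i$-th component of $\theta_k^{l-1}$)

$$ \frac{\partial F}{\partial x_i^l} = -\epsilon_i^l + f'(x_i^l) \sum_{k=1}^n \theta_{k,i}^{l-1} \epsilon_k^{l-1}$$

and to update the weights we use the partial derivative of $F$ with respect to $\theta_{i,j}^{l-1}$

$$ \frac{\partial F}{\partial \theta_{i,j}^{l-1}} = \epsilon_i^{l-1} f(x_j^l) $$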

In [4]:
import matplotlib.pyplot as plt
import numpy as np
from PredCodMLP import PredCodMLP

In [2]:
from read_image import *

# Read the MNIST training and test splits from the local data/ directory.
# The tuple elements are evaluated left to right, so the files are opened in
# the order: train images, train labels, test images, test labels.
train_images, train_labels = (
    read_mnist_images('data/train-images-idx3-ubyte.gz'),
    read_mnist_labels('data/train-labels-idx1-ubyte.gz'),
)
test_images, test_labels = (
    read_mnist_images('data/t10k-images-idx3-ubyte.gz'),
    read_mnist_labels('data/t10k-labels-idx1-ubyte.gz'),
)

# Both image arrays are unpacked as three-dimensional (count, H, W); a
# singleton axis is inserted after the count so they become (count, 1, H, W).
num_train, H, W = train_images.shape
num_test, _, _ = test_images.shape
train_images = train_images.reshape((num_train, 1, H, W))
test_images = test_images.reshape((num_test, 1, H, W))

# Report the size of each split.
for caption, count in (("num_train: ", num_train), ("num_test: ", num_test)):
    print(caption, count)

num_train:  60000
num_test:  10000


In [6]:
# Synthetic spiral dataset: K arms of N two-dimensional points each.
N = 100  # number of points per class
D = 2  # dimensionality
K = 3  # number of classes
X = np.zeros((N * K, D))  # data matrix (each row = single example)
y = np.zeros(N * K, dtype='uint8')  # class labels

for cls in range(K):
    rows = slice(N * cls, N * (cls + 1))  # rows of X / y owned by this class
    radius = np.linspace(0.0, 1, N)
    # Each arm sweeps its own 4-radian window, jittered with Gaussian noise.
    angle = np.linspace(cls * 4, (cls + 1) * 4, N) + np.random.randn(N) * 0.2
    X[rows] = np.column_stack((radius * np.sin(angle), radius * np.cos(angle)))
    y[rows] = cls

# To visualize the data, uncomment:
#plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.Spectral)
#plt.show()

In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap, random, value_and_grad, nn
from typing import Tuple
from adam import adam, make_default_adam_config

def cross_entropy_loss(scores: jnp.ndarray, target: jnp.ndarray):
    """Return the softmax cross-entropy of one score vector for one class index.

    Args:
        scores: Unnormalised class scores for a single example.
        target: Index of the correct class; used to index ``scores``.

    Returns:
        ``-log(softmax(scores)[target])``, computed in log space.
    """
    # Shift into a fresh array so the largest entry is 0: exp() cannot
    # overflow, and the caller's array is never modified in place.
    shifted = scores - scores.max()
    # log(sum(exp(shifted))) - shifted[target] equals -log(probs[target]) but
    # stays finite even when exp(shifted[target]) underflows to zero.
    return jnp.log(jnp.exp(shifted).sum()) - shifted[target]

def forward(params: list[Tuple], input: jnp.ndarray, target: jnp.ndarray):
    """Push one example through the fully connected stack and return its loss.

    Args:
        params: Sequence of ``(W, b)`` pairs, one per layer, applied in order.
        input: Input to the first layer.
        target: Passed through unchanged to ``cross_entropy_loss``.

    Returns:
        The value of ``cross_entropy_loss`` for the last layer's
        pre-activation output and ``target``.
    """
    activation = input
    for weights, bias in params:
        # Keep the pre-activation value: the final layer's scores go to the
        # loss without passing through the ReLU.
        scores = activation.dot(weights) + bias
        activation = nn.relu(scores)
    return cross_entropy_loss(scores, target)

def batched_forward(params, X, y):
    """Return the mean of ``forward`` over the leading axis of ``X`` and ``y``.

    ``params`` is shared by every example; ``X`` and ``y`` are mapped over
    their first axis.
    """
    per_example_loss = vmap(forward, in_axes=(None, 0, 0))
    return jnp.mean(per_example_loss(params, X, y))

def init_layer(Din, Dout, key):
    """Create the ``(W, b)`` pair for one fully connected layer.

    Args:
        Din: Number of input features.
        Dout: Number of output features.
        key: JAX PRNG key used to draw the weights.

    Returns:
        A tuple ``(W, b)``: ``W`` has shape ``(Din, Dout)`` and holds scaled
        standard-normal samples, ``b`` is a zero vector of length ``Dout``.
    """
    # NOTE(review): the scale is sqrt(2 / (Din * Dout)); confirm that this,
    # rather than a fan-in-only scale, is what is intended.
    scale = jnp.sqrt(2 / (Din * Dout))
    weights = scale * random.normal(key, (Din, Dout))
    biases = jnp.zeros(Dout)
    return (weights, biases)

@jit
def update(params, config) -> tuple[list[Tuple], list[Tuple], jnp.ndarray]:
    """Apply one optimiser step to every layer of the network.

    The loss is evaluated on the module-level ``X`` and ``y``.
    NOTE(review): because this function is jitted, those globals are
    presumably captured when it is traced; confirm before swapping datasets.

    Args:
        params: List of ``(W, b)`` pairs, one per layer.
        config: List of ``(W_config, b_config)`` pairs parallel to ``params``;
            each entry is handed to ``adam`` with its parameter.

    Returns:
        ``(new_params, new_config, loss)``: the updated parameter list, the
        optimiser configs returned by ``adam``, and the loss measured before
        the step.
    """
    # value_and_grad differentiates with respect to the first argument, so
    # ``grads`` has the same list-of-pairs layout as ``params``.
    loss, grads = value_and_grad(batched_forward)(params, X, y)
    new_params = []
    new_config = []
    # Walk however many layers ``params`` holds instead of assuming two.
    for (W, b), (dW, db), (W_config, b_config) in zip(params, grads, config):
        W, W_config = adam(W, dW, W_config)
        b, b_config = adam(b, db, b_config)
        new_params.append((W, b))
        new_config.append((W_config, b_config))
    return new_params, new_config, loss


# Seed the PRNG and derive one subkey per layer.
key = random.PRNGKey(0)
key, *subkeys = random.split(key, 3)

# Two-layer MLP: D inputs -> 100 hidden units -> K class scores.
layer_shapes = ((D, 100), (100, K))
params = [
    init_layer(n_in, n_out, subkey)
    for (n_in, n_out), subkey in zip(layer_shapes, subkeys)
]

# Explicit alternative to the empty configs below, kept for reference:
#   W1 = make_default_adam_config(params[0][0])
#   b1 = make_default_adam_config(params[0][1])
#   W2 = make_default_adam_config(params[1][0])
#   b2 = make_default_adam_config(params[1][1])
#   config = [(W1,b1),(W2,b2)]

# Every weight and bias starts with its own empty optimiser config.
# NOTE(review): presumably `adam` fills in defaults for an empty dict; confirm in adam.py.
config = [({}, {}) for _ in params]

# Run 100 update steps, printing the loss after each one.
for step in range(100):
    params, config, loss = update(params, config)
    print(loss)

In [4]:
from PredCodMLP import cross_entropy_loss

num_train, C, H, W = train_images.shape
K, batch_size = 10, 256

# Flatten every test image to one row and rescale so that a pixel value of 0
# maps to -1 and a value of 255 maps to 1.
test_batch = np.array(test_images, dtype=np.float32).reshape(num_test, -1)
test_batch = test_batch / (255/2) - 1

# Layer sizes: one input per pixel, 800 hidden units, K outputs.
predcod = PredCodMLP([H*W, 800, K])
for t in range(1000):
    # Sample a mini-batch and preprocess it exactly like the test images.
    batch_mask = np.random.choice(num_train, batch_size)
    batch = np.array(train_images[batch_mask], dtype=np.float32).reshape(batch_size, -1)
    batch = batch / (255/2) - 1
    labels = train_labels[batch_mask]
    predcod.train_step(batch, labels, 1e-3)

    # Only report the test loss on every 100th step.
    if t % 100 != 0:
        continue
    loss, _ = cross_entropy_loss(predcod.predict(test_batch), test_labels)
    print(loss)
#classes = np.argmax(preds, 1)
#print(np.mean(y.flatten() == classes))



: 

In [None]:
# Build the extension modules under fast_layers/ in place, then return to the
# parent directory.
%cd fast_layers
# NOTE(review): `%conda activate` may not change the environment used by the
# `!python` call below; confirm the build runs with the intended interpreter.
%conda activate predcod
!python setup.py build_ext --inplace
%cd ..

In [19]:
import cnn
from cnn import CNN
from adam import adam
import jax.numpy as jnp
from jax import random, jit

# Input geometry comes from the four-dimensional training image array.
num_train, C, H, W = train_images.shape

# Model and optimisation hyper-parameters.
batch_size, num_filters, num_classes = 512, 25, 10

# Seed the PRNG and hand the derived subkeys to the model constructor.
key = random.PRNGKey(0)
key, *cnnkeys = random.split(key, 3)
model = CNN(batch_size, C, H, W, num_filters, num_classes, cnnkeys)

# One initially empty optimiser config per named parameter.
config = {name: {} for name in ('W1', 'b1', 'W2', 'b2')}


@jit
def update(params, inputs, labels, config: dict):
    """Run one forward/backward pass and update every parameter with ``adam``.

    NOTE(review): the recorded output of this cell is
    ``AttributeError: 'ShapedArray' object has no attribute 'copy'``, which
    suggests something on the ``cnn.loss`` / ``cnn.backward`` path is not
    traceable under ``jit``; confirm before relying on the decorator.

    Args:
        params: Mapping from parameter name to its current value.
        inputs: Batch passed to ``cnn.loss``.
        labels: Labels passed to ``cnn.loss``.
        config: Mapping from parameter name to that parameter's optimiser
            config.

    Returns:
        ``(new_params, new_config, loss)``.
    """
    loss, dscores, cache = cnn.loss(params, inputs, labels)
    grads = cnn.backward(dscores, cache)
    new_params = {}
    # Copy so the caller's dict is never mutated and unrelated entries survive.
    new_config = dict(config)
    # Iterating a dict yields its keys; unpacking each key into two names
    # would split the key string itself instead of giving (name, value).
    for k in params:
        new_params[k], new_config[k] = adam(params[k], grads[k], config[k])
    return new_params, new_config, loss

# Rescale the test images once, with the same mapping applied to every
# training batch below, so validation sees inputs on the same scale.
test_inputs = np.array(test_images, dtype=np.float32)
test_inputs /= (255/2)
test_inputs -= 1

for t in range(1000):
    # NOTE(review): batch_key is never used; sampling below goes through
    # NumPy's global RNG rather than the JAX key.
    key, batch_key = random.split(key)
    batch_mask = np.random.choice(num_train, batch_size)
    batch = np.array(train_images[batch_mask], dtype=np.float32)
    batch /= (255/2)
    batch -= 1
    labels = train_labels[batch_mask]
    model.params, config, loss = update(model.params, batch, labels, config)
    if t % 100 == 0:
        # Estimate training accuracy on a random subset of 2**12 examples.
        batch_mask = np.random.choice(num_train, 2**12)
        batch = np.array(train_images[batch_mask], dtype=np.float32)
        batch_labels = train_labels[batch_mask]
        #batch = np.array(train_images, dtype=np.float32)
        batch /= (255/2)
        batch -= 1
        scores, _ = cnn.forward(model.params, batch)
        train_acc = np.mean(np.argmax(scores, 1) == batch_labels)
        print("t: ", t, "training accuracy: ", train_acc)
        # cnn.forward takes the parameters first, as in the call above.
        scores, _ = cnn.forward(model.params, test_inputs)
        test_acc = np.mean(np.argmax(scores, 1) == test_labels)
        print("validation accuracy: ", test_acc)

AttributeError: 'ShapedArray' object has no attribute 'copy'

In [6]:
"""
scores, _ = cnn.forward(train_images)
train_acc = np.mean(np.argmax(scores, 1) == train_labels)
print("training accuracy: ", train_acc)
"""
# Rescale the test images the same way the training batches are rescaled
# before scoring them.
test_inputs = np.array(test_images, dtype=np.float32)
test_inputs /= (255/2)
test_inputs -= 1
# cnn.forward takes the parameters first, matching its use in the training loop.
scores, _ = cnn.forward(model.params, test_inputs)
test_acc = np.mean(np.argmax(scores, 1) == test_labels)
print("validation accuracy: ", test_acc)

