# Implementation of zero-divergence Inference Learning in a Predictive Coding Network

## Predictive Coding Network

A predictive coding network is a hierarchical probabilistic model in which the variables of each layer are predicted from the variables of the layer above.

Variables on adjacent levels are assumed to be related by

$$ P(x_i^l | \bar x^{l+1}) = \mathcal{N}( x_i^l; \mu_i^l, \Sigma_i^l) $$

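where the mean $\mu_i^l$ is a weighted sum of the activations $f(\bar x^{l+1})$ of the layer above, and $\theta_i^{l+1}$ is the vector of weights from layer $l+1$ into node $i$ of layer $l$:

$$ \mu_i^l = {\theta_i^{l+1}}^T f(\bar x^{l+1}) $$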

The objective is to maximize the log-probability of all layers given the top layer $\bar x^L$:

$$ F = \ln P(\bar x^0, \dots, \bar x^{L-1} | \bar x^L) $$

Because each layer depends only on the layer directly above it, this factorizes into

$$ \begin{align*}
    F &= \sum_{l=0}^{L-1} \ln P(\bar x^l | \bar x^{l+1})    \\
    &= \sum_{l=0}^{L-1} \sum_{i=1}^n \ln \mathcal{N}( x_i^l; \mu_i^l, \Sigma_i^l)    \\
    &= \sum_{l=0}^{L-1} \sum_{i=1}^n \left[ \ln \frac{1}{\sqrt{2\pi\Sigma_i^l}} - \frac{1}{2}\frac{(x_i^l - \mu_i^l)^2}{\Sigma_i^l} \right]
\end{align*} $$

Dropping the first term, which is constant with respect to $x_i^l$ and the weights and therefore vanishes when we differentiate, leaves

$$ F = -\frac{1}{2} \sum_{l=0}^{L-1} \sum_{i=1}^n \frac{(x_i^l - \mu_i^l)^2}{\Sigma_i^l} $$

In this model we assume all variances are 1. Writing the prediction error as $\epsilon_i^l = x_i^l - \mu_i^l$, the objective becomes

$$ F = -\frac{1}{2} \sum_{l=0}^{L-1} \sum_{i=1}^n (\epsilon_i^l)^2 $$
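
As a quick numerical illustration of the unit-variance objective, the sketch below evaluates $F$ for a small 3-2-1 network. The function name `free_energy`, the list layout (`xs[l]` holds layer $l$, `thetas[l]` maps layer $l+1$ to layer $l$), the ReLU activation and the numbers are all assumptions made for this example, not part of the derivation.

```python
import numpy as np

def relu(x):
    return np.maximum(0.0, x)

def free_energy(xs, thetas, f=relu):
    """Return F = -1/2 * (sum of squared prediction errors over all predicted layers).

    xs[l] is the (n_l, 1) value vector of layer l; xs[-1] is the top layer x^L.
    thetas[l] has shape (n_l, n_{l+1}) and maps f(xs[l + 1]) to the mean of layer l.
    """
    total = 0.0
    for l in range(len(xs) - 1):
        mu = thetas[l] @ f(xs[l + 1])   # predicted mean of layer l
        eps = xs[l] - mu                # prediction error of layer l
        total += float(np.sum(eps ** 2))
    return -0.5 * total

# Illustrative values: layer 0 is the bottom layer, layer 2 the top layer.
xs = [np.array([[1.0]]),
      np.array([[8.0], [10.0]]),
      np.array([[0.0], [1.0], [1.0]])]
thetas = [np.array([[1.0, 2.0]]),
          np.array([[1.0, 3.0, 5.0],
                    [2.0, 4.0, 6.0]])]

print(free_energy(xs, thetas))  # -364.5
```

Here the middle layer matches its prediction exactly, so only the bottom layer contributes: its error is $1 - 28 = -27$ and $F = -\frac{1}{2} \cdot 27^2$.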

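During inference, each $x_i^l$ is updated along the partial derivative of $F$ with respect to $x_i^l$. The value $x_i^l$ appears in its own error $\epsilon_i^l$ and, through $f(x_i^l)$, in the errors of the layer below:

$$ \frac{\partial F}{\partial x_i^l} = -\epsilon_i^l + f'(x_i^l) \sum_{k=1}^n \theta_{k,i}^l \epsilon_k^{l-1} $$

where $\theta_{k,i}^l$ is the weight from node $i$ of layer $l$ to node $k$ of layer $l-1$.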
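
One way to sanity-check this derivative is to compare it against a finite-difference estimate of $F$. The sketch below does that for the hidden layer of a small random network. All helper names are hypothetical, and `tanh` is used as the activation here (an assumption made so that $F$ is smooth everywhere; the notebook code uses ReLU).

```python
import numpy as np

def f(x):
    return np.tanh(x)

def f_prime(x):
    return 1.0 - np.tanh(x) ** 2

def errors(xs, thetas):
    # eps^l = x^l - mu^l for every predicted layer l = 0 .. L-1.
    # thetas[l] maps layer l+1 to layer l, i.e. it plays the role of theta^{l+1}.
    return [xs[l] - thetas[l] @ f(xs[l + 1]) for l in range(len(xs) - 1)]

def free_energy(xs, thetas):
    return -0.5 * sum(float(np.sum(e ** 2)) for e in errors(xs, thetas))

def grad_x(xs, thetas, l):
    # Analytic dF/dx^l for a hidden layer 0 < l < L.
    eps = errors(xs, thetas)
    return -eps[l] + f_prime(xs[l]) * (thetas[l - 1].T @ eps[l - 1])

def numeric_grad_x(xs, thetas, l, h=1e-6):
    # Central finite differences, one node of layer l at a time.
    grad = np.zeros_like(xs[l])
    for i in range(xs[l].shape[0]):
        up = [x.copy() for x in xs]
        down = [x.copy() for x in xs]
        up[l][i, 0] += h
        down[l][i, 0] -= h
        grad[i, 0] = (free_energy(up, thetas) - free_energy(down, thetas)) / (2 * h)
    return grad

rng = np.random.default_rng(0)
sizes = [1, 2, 3]  # layer 0 (bottom) up to layer L (top)
xs = [rng.normal(size=(n, 1)) for n in sizes]
thetas = [rng.normal(size=(sizes[l], sizes[l + 1])) for l in range(len(sizes) - 1)]

analytic = grad_x(xs, thetas, 1)
numeric = numeric_grad_x(xs, thetas, 1)
print(np.allclose(analytic, numeric, atol=1e-6))
```

The transpose in `thetas[l - 1].T` corresponds to summing over the first weight index, the nodes of the layer below.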

The weights are updated along the partial derivative of $F$ with respect to each weight:

$$ \frac{\partial F}{\partial \theta_{i,j}^l} = \epsilon_i^{l-1} f(x_j^l) $$
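
Written for a whole layer at once, this gradient is the outer product of the error vector of layer $l-1$ with the activations of layer $l$. A minimal sketch follows, with a hypothetical helper `weight_grad`, a ReLU activation and illustrative numbers (all assumptions for this example):

```python
import numpy as np

def relu(x):
    return np.maximum(0.0, x)

def weight_grad(eps_below, x_above, f=relu):
    """Return dF/dtheta^l with entry (i, j) = eps_i^{l-1} * f(x_j^l).

    eps_below: (n_{l-1}, 1) prediction errors of layer l-1.
    x_above:   (n_l, 1) values of layer l.
    """
    return eps_below @ f(x_above).T

eps0 = np.array([[-27.0]])        # error of a single bottom-layer node
x1 = np.array([[8.0], [10.0]])    # values of the two nodes in the layer above

grad = weight_grad(eps0, x1)
print(grad)  # [[-216. -270.]]
```

A gradient-ascent step on $F$ would then add a small multiple of `grad` to the weight matrix.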

In [1]:
from read_image import *

# MNIST ships images (idx3) and labels (idx1) in separate files.
train_images = read_mnist_images('data/train-images-idx3-ubyte.gz')
train_labels = read_mnist_labels('data/train-labels-idx1-ubyte.gz')
test_images = read_mnist_images('data/t10k-images-idx3-ubyte.gz')
test_labels = read_mnist_labels('data/t10k-labels-idx1-ubyte.gz')
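
The `read_image` module itself is not shown here. For readers without it, the following is a minimal sketch of how gzip-compressed IDX files can be parsed, assuming the standard MNIST layout: a big-endian header whose magic number is 2051 for image files and 2049 for label files. The names `load_idx_images` and `load_idx_labels` are hypothetical and are not the functions imported above.

```python
import gzip
import struct

import numpy as np

def load_idx_images(path):
    """Read a gzip-compressed IDX3 image file into an (n, rows, cols) uint8 array."""
    with gzip.open(path, "rb") as fh:
        magic, n, rows, cols = struct.unpack(">IIII", fh.read(16))
        if magic != 2051:
            raise ValueError(f"unexpected magic number {magic} for an image file")
        data = np.frombuffer(fh.read(), dtype=np.uint8)
    return data.reshape(n, rows, cols)

def load_idx_labels(path):
    """Read a gzip-compressed IDX1 label file into an (n,) uint8 array."""
    with gzip.open(path, "rb") as fh:
        magic, n = struct.unpack(">II", fh.read(8))
        if magic != 2049:
            raise ValueError(f"unexpected magic number {magic} for a label file")
        data = np.frombuffer(fh.read(), dtype=np.uint8)
    return data[:n]
```

Checking the magic number makes a mix-up between image and label files fail loudly instead of silently returning header bytes as labels.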

In [32]:
import numpy as np

input = np.array([[0,0,1],
            [0,1,1],
            [1,0,1],
            [1,1,1]])

output = np.array([[0],
            [1],
            [1],
            [0]])

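# Random initialisation for a 3-4-1 network (currently unused):
#w1 = 2*np.random.random((4,1)) - 1
#w2 = 2*np.random.random((3,4)) - 1
# Fixed weights for a 3-2-1 network, so the printed results are reproducible.
w1 = np.array([ [1.],
                [2]])
w2 = np.array([ [1.,2],
                [3,4],
                [5,6]])
# params[i] has shape (size of layer i, size of layer i+1) and maps layer i to layer i+1.
params = [w2, w1]

# Node values per layer: input (3 nodes), hidden (2 nodes), output (1 node).
layers = [np.zeros((3,1)), np.zeros((2,1)), np.zeros((1,1))]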

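# prediction
# Inference step size; zero-divergence inference learning uses gamma = 1.
gamma = 1
def predict(input: np.ndarray) -> np.ndarray:
    """Run one inference sweep from zero-initialised layers and return the output layer.

    With gamma = 1 and every layer starting at zero, each layer is set to its
    prediction mu, so the sweep is equivalent to a normal forward pass.
    """
    layers = [np.zeros((3,1)), np.zeros((2,1)), np.zeros((1,1))]
    # Prediction error of the first hidden layer given the input.
    curr_mu = params[0].T.dot(input)
    curr_err = layers[1] - curr_mu
    for i in range(1, len(layers)-1):
        # Error of layer i+1 given the current (pre-update) value of layer i.
        activated = np.maximum(0, layers[i])
        next_mu = params[i].T.dot(activated)
        next_err = layers[i+1] - next_mu
        relu_mask = layers[i] > 0  # derivative of ReLU
        # Inference update: -own error + f'(x) * (weights . error of the predicted layer).
        layers[i] += gamma * (-curr_err + relu_mask * params[i].dot(next_err))
        # Recompute the error of layer i+1 from the updated layer i.
        curr_mu = params[i].T.dot(np.maximum(0, layers[i]))
        curr_err = layers[i+1] - curr_mu
    # The output layer predicts nothing further, so only its own error term remains.
    layers[-1] += gamma * -curr_err
    return layers[-1]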

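def learn(input: np.ndarray, out: np.ndarray, lr=0.01):
    """Clamp the output layer to `out` and update `params` in place.

    Runs len(layers) inference sweeps. params[i] is updated on sweep
    t = len(layers) - 1 - i, so the weights closest to the output go first.
    """
    # Start the hidden layers from zero so that the first sweep (t = 0)
    # acts as a forward pass, even if learn() was called before.
    for i in range(1, len(layers)-1):
        layers[i] = np.zeros_like(layers[i])
    layers[-1] = out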
    for t in range(0, len(layers)):
        curr_mu = params[0].T.dot(input)
        curr_err = layers[1] - curr_mu
        if t == len(layers) - 1:
            params[0] += lr * input.dot(curr_err.T)
        for i in range(1, len(layers)-1):
            activated = np.maximum(0, layers[i])
            next_mu = params[i].T.dot(activated)
            next_err = layers[i+1] - next_mu
            relu_mask = layers[i] > 0
            layers[i] += gamma * (-curr_err + relu_mask * params[i].dot(next_err))
            if len(layers) - 1 - i == t:
                # Weight gradient f(x) * error, using the activation and error
                # from the start of this sweep (before layers[i] was changed).
                params[i] += lr * activated.dot(next_err.T)
            #can we avoid recalculating curr_mu during learning?
            #during prediction can just do a normal forward pass through network
            curr_mu = params[i].T.dot(np.maximum(0, layers[i]))
            curr_err = layers[i+1] - curr_mu
    
print(predict(input[1].reshape((3,1))))
learn(input[1].reshape((3,1)), output[1], 0.1)
#print(layers)
print(predict(input[1].reshape((3,1))))

[[28.]]
[[1.00005653]]
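
"Zero-divergence" refers to the weight updates matching those of backpropagation. The sketch below is one way to check this numerically for a network with a single hidden layer. It assumes a squared-error loss, a ReLU hidden layer with a linear output, an inference rate of 1, a hidden layer initialised to its forward-pass value, and each weight matrix being updated as soon as a non-zero error reaches it. `backprop_updates` and `zil_updates` are hypothetical helpers with a schedule hard-coded for two weight matrices; they are not the `learn` function above.

```python
import numpy as np

def relu(x):
    return np.maximum(0.0, x)

def backprop_updates(W, x_in, y, lr):
    """Gradient-descent updates for the loss 1/2 * (y - out)^2."""
    h = W[0].T @ x_in                       # hidden pre-activation
    out = W[1].T @ relu(h)                  # linear output
    err_out = y - out
    delta_h = (h > 0) * (W[1] @ err_out)    # error backpropagated to the hidden layer
    return [lr * x_in @ delta_h.T, lr * relu(h) @ err_out.T]

def zil_updates(W, x_in, y, lr, gamma=1.0):
    """Predictive-coding updates under the assumptions listed in the lead-in."""
    h = W[0].T @ x_in                       # hidden layer starts at its prediction
    # t = 0: only the output error is non-zero; update the output-side weights.
    err_h = h - W[0].T @ x_in
    err_out = y - W[1].T @ relu(h)
    dW1 = lr * relu(h) @ err_out.T
    h = h + gamma * (-err_h + (h > 0) * (W[1] @ err_out))
    # t = 1: the error has reached the hidden layer; update the input-side weights.
    err_h = h - W[0].T @ x_in
    dW0 = lr * x_in @ err_h.T
    return [dW0, dW1]

W = [np.array([[1., 2.], [3., 4.], [5., 6.]]), np.array([[1.], [2.]])]
x_in = np.array([[0.], [1.], [1.]])
y = np.array([[1.]])

bp = backprop_updates(W, x_in, y, lr=0.1)
zil = zil_updates(W, x_in, y, lr=0.1)
print(all(np.allclose(a, b) for a, b in zip(bp, zil)))  # True
```

Both functions compute their updates from the same, unmodified weights; applying `dW1` before running the second sweep would break the equivalence.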
