In [1]:

import numpy as np
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [2]:

# Download and load MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Use a smaller subset for faster training (for demonstration)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

100.0%
100.0%
100.0%
100.0%


In [None]:

class PredictiveCodingModel:
    """Two-layer predictive-coding classifier built on NumPy column vectors.

    Layers: input ``x0`` (``input_size``) -> hidden ``x1`` (``hidden_size``,
    sigmoid) -> output ``x2`` (``output_size``, softmax). The hidden layer also
    carries a lateral weight matrix ``L1``.

    Vectors are handled as column vectors of shape ``(n, 1)``. 1-D inputs passed
    to ``forward``, ``compute_errors`` and ``predict`` are reshaped to columns so
    they do not silently broadcast into ``(n, n)`` matrices.

    Args:
        input_size: Dimensionality of the input vector.
        hidden_size: Number of hidden units.
        output_size: Number of output classes.
        lr: Step size used for both state relaxation and weight updates.
        seed: Optional seed for weight initialisation. ``None`` (default) draws
            from NumPy's global RNG, exactly as before (so ``np.random.seed``
            still controls it).
    """

    def __init__(self, input_size=784, hidden_size=256, output_size=10, lr=1e-3, seed=None):
        # Feedforward weights
        if seed is None:
            W1 = np.random.randn(hidden_size, input_size)
            W2 = np.random.randn(output_size, hidden_size)
        else:
            rng = np.random.default_rng(seed)
            W1 = rng.standard_normal((hidden_size, input_size))
            W2 = rng.standard_normal((output_size, hidden_size))
        self.W1 = W1 * 0.1
        self.W2 = W2 * 0.1
        # Lateral weights (hidden layer): small diagonal initialisation
        self.L1 = np.eye(hidden_size) * 0.01
        # States (column vectors). Nothing in this class updates them after init.
        self.x1 = np.zeros((hidden_size, 1))
        self.x2 = np.zeros((output_size, 1))
        # Learning rate
        self.lr = lr

    @staticmethod
    def _as_column(v):
        """Return ``v`` as an array, reshaping a 1-D vector to shape ``(n, 1)``."""
        v = np.asarray(v)
        return v.reshape(-1, 1) if v.ndim == 1 else v

    def sigmoid(self, x):
        """Logistic function 1 / (1 + exp(-x)), evaluated without overflow.

        exp(-logaddexp(0, -x)) is algebraically identical to the naive form but
        never evaluates exp() of a large positive number.
        """
        return np.exp(-np.logaddexp(0.0, -x))

    def softmax(self, x):
        """Softmax over axis 0 (one distribution per column)."""
        # Subtract the per-column max (not the global max) so each column of a
        # batched (classes, batch) input is stabilised independently; with a
        # global max, a column far below it underflows to 0/0 = NaN.
        e_x = np.exp(x - np.max(x, axis=0, keepdims=True))
        return e_x / e_x.sum(axis=0, keepdims=True)

    def forward(self, x0):
        """Predict hidden and output activations for input ``x0``.

        Args:
            x0: Input column vector ``(input_size, 1)`` (a 1-D vector is
                reshaped to a column).

        Returns:
            ``(pred_x1, pred_x2)``: sigmoid hidden activations and softmax
            output probabilities. The lateral term uses the stored ``self.x1``.
        """
        x0 = self._as_column(x0)
        pred_x1 = self.sigmoid(self.W1 @ x0 + self.L1 @ self.x1)
        pred_x2 = self.softmax(self.W2 @ pred_x1)
        return pred_x1, pred_x2

    def compute_errors(self, x0, x1, x2, y):
        """Compute the prediction errors at each layer.

        Returns:
            ``(e0, e1, e2)`` where
            ``e0 = x0 - W1.T @ x1`` (input reconstruction error),
            ``e1 = x1 - sigmoid(W1 @ x0 + L1 @ x1)`` (hidden error),
            ``e2 = x2 - y`` (output minus one-hot target).
        """
        x0, x1, x2, y = (self._as_column(v) for v in (x0, x1, x2, y))
        e0 = x0 - self.W1.T @ x1
        e1 = x1 - self.sigmoid(self.W1 @ x0 + self.L1 @ x1)
        e2 = x2 - y
        return e0, e1, e2

    def update_states(self, x0, x1, x2, y, e0, e1, e2):
        """Take one relaxation step on the hidden and output states.

        ``x1`` and ``x2`` are modified **in place** (``+=``) and also returned.
        ``x0`` and ``y`` are accepted for interface symmetry but are unused.

        Returns:
            The updated ``(x1, x2)``.
        """
        # NOTE(review): the hidden update receives no top-down term from the
        # output error (e.g. W2.T @ e2), so the label never influences x1 —
        # confirm this is intended.
        dx1 = self.W1 @ e0 + self.L1 @ e1 - e1
        x1 += self.lr * dx1
        # When e2 comes from compute_errors (e2 = x2 - y) this moves x2 toward y.
        dx2 = -e2
        x2 += self.lr * dx2
        return x1, x2

    def update_weights(self, x0, x1, x2, e0, e1, e2):
        """Apply one weight update from the layer errors (``x2``/``e0`` unused)."""
        # e1 = x1 - prediction, so adding outer(e1, .) raises the pre-activation
        # of units whose prediction is too low, shrinking e1.
        self.W1 += self.lr * np.outer(e1, x0)
        # e2 = x2 - y is (output - target): for a softmax output this is the
        # cross-entropy gradient w.r.t. the logits, so W2 must move AGAINST it.
        # (`+=` here would be gradient ascent and push the true class down.)
        self.W2 -= self.lr * np.outer(e2, x1)
        # Update lateral weights (hidden), same sign convention as W1.
        self.L1 += self.lr * np.outer(e1, x1)

    def predict(self, x0):
        """Return the predicted class index for ``x0``.

        Args:
            x0: A single input (column ``(input_size, 1)`` or 1-D vector), or a
                batch of inputs as columns ``(input_size, batch)``.

        Returns:
            The argmax class index for a single input, or an array of per-column
            indices of shape ``(batch,)`` for a batch.
        """
        x0 = self._as_column(x0)
        # The hidden state is taken to be zero here, so the lateral term
        # L1 @ 0 contributes nothing and is omitted.
        x1 = self.sigmoid(self.W1 @ x0)
        x2 = self.softmax(self.W2 @ x1)
        if x2.shape[1] == 1:
            return np.argmax(x2)
        return np.argmax(x2, axis=0)