# Discriminative PC on MNIST

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/thebuckleylab/jpc/blob/main/examples/discriminative_pc.ipynb)

This notebook demonstrates how to train a simple feedforward network with predictive coding (PC) to discriminate or classify MNIST digits.

In [1]:
# %%capture
# !pip install torch==2.3.1
# !pip install torchvision==0.18.1

In [2]:
import jpc

import jax
import equinox as eqx
import equinox.nn as nn
import optax

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
import warnings
warnings.simplefilter('ignore')  # ignore warnings
from pathlib import Path
import numpy as np
import os
import pickle

## Hyperparameters

We define some global parameters, including the network architecture, learning rate, batch size, etc.

In [3]:
SEED = 0  # seed of the JAX PRNG key used to initialise the network

# Network architecture (passed to `jpc.make_mlp` below).
INPUT_DIM = 784
WIDTH = 300
DEPTH = 3
OUTPUT_DIM = 10
ACT_FN = "relu"

# Optimisation and evaluation schedule (passed to `train` below).
LEARNING_RATE = 1e-3
BATCH_SIZE = 64
TEST_EVERY = 100  # evaluate on the test set every this many training iterations
N_TRAIN_ITERS = 300

## Dataset

Some utils to fetch MNIST.

In [4]:
def get_mnist_loaders(batch_size):
    """Build the MNIST train and test data loaders.

    Both loaders serve normalised data, shuffle their samples and drop the
    final incomplete batch, so every batch holds exactly `batch_size` items.

    Returns:
        A `(train_loader, test_loader)` tuple.
    """
    train_set, test_set = (
        MNIST(train=is_train, normalise=True) for is_train in (True, False)
    )
    # Identical settings for both splits.
    loader_kwargs = dict(batch_size=batch_size, shuffle=True, drop_last=True)
    return (
        DataLoader(dataset=train_set, **loader_kwargs),
        DataLoader(dataset=test_set, **loader_kwargs),
    )


class MNIST(datasets.MNIST):
    """MNIST dataset that yields flattened images and one-hot label vectors."""

    def __init__(self, train, normalise=True, save_dir="data"):
        """Download MNIST into `save_dir` if needed and set up the transform.

        Args:
            train: forwarded to `datasets.MNIST` to select the split.
            normalise: if true, standardise pixels with mean 0.1307 and
                std 0.3081 after converting the image to a tensor.
            save_dir: root directory in which the dataset is stored.
        """
        steps = [transforms.ToTensor()]
        if normalise:
            # `Normalize` documents `mean`/`std` as sequences, so pass
            # one-element tuples rather than parenthesised floats.
            steps.append(
                transforms.Normalize(mean=(0.1307,), std=(0.3081,))
            )
        super().__init__(
            save_dir,
            download=True,
            train=train,
            transform=transforms.Compose(steps),
        )

    def __getitem__(self, index):
        """Return the `index`-th sample as `(flat_image, one_hot_label)`."""
        img, label = super().__getitem__(index)
        return torch.flatten(img), one_hot(label)


def one_hot(labels, n_classes=10):
    """One-hot encode `labels` by selecting rows of an identity matrix.

    Args:
        labels: a class index, or anything usable to index the rows of a
            tensor (e.g. an integer tensor of class indices).
        n_classes: length of each one-hot vector.

    Returns:
        The selected row(s) of the `n_classes` x `n_classes` identity matrix.
    """
    return torch.eye(n_classes)[labels]
    

## Network

For `jpc` to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like [`nn.Sequential()`](https://docs.kidger.site/equinox/api/nn/sequential/#equinox.nn.Sequential) in [Equinox](https://github.com/patrick-kidger/equinox). For example, we can define a ReLU MLP with two hidden layers as follows:

In [5]:
key = jax.random.PRNGKey(SEED)
# Split into four keys; the first is discarded and each linear layer gets
# one of the remaining three.
_, *subkeys = jax.random.split(key, 4)

# Each hidden layer is a linear map followed by a ReLU.
first_hidden = nn.Sequential(
    [nn.Linear(784, 300, key=subkeys[0]), nn.Lambda(jax.nn.relu)],
)
second_hidden = nn.Sequential(
    [nn.Linear(300, 300, key=subkeys[1]), nn.Lambda(jax.nn.relu)],
)
# The output layer is a plain linear map with no activation.
readout = nn.Linear(300, 10, key=subkeys[2])

network = [first_hidden, second_hidden, readout]



You can also use [`jpc.make_mlp()`](https://thebuckleylab.github.io/jpc/api/Utils/#jpc.make_mlp) to define a multi-layer perceptron (MLP) or fully connected network.

In [6]:
# Build an MLP from the hyperparameters defined above; this rebinds `network`,
# replacing the manually assembled one.
# NOTE(review): `key` is the same PRNG key that was already split for the
# manual example — confirm that reusing it here is intended.
network = jpc.make_mlp(
    key,
    input_dim=INPUT_DIM,
    width=WIDTH,
    depth=DEPTH,
    output_dim=OUTPUT_DIM,
    act_fn=ACT_FN,
    use_bias=True
)
print(network)

[Sequential(
  layers=(
    Lambda(fn=Identity()),
    Linear(
      weight=f32[300,784],
      bias=f32[300],
      in_features=784,
      out_features=300,
      use_bias=True
    )
  )
), Sequential(
  layers=(
    Lambda(fn=<PjitFunction of <function relu at 0x77699a43ec00>>),
    Linear(
      weight=f32[300,300],
      bias=f32[300],
      in_features=300,
      out_features=300,
      use_bias=True
    )
  )
), Sequential(
  layers=(
    Lambda(fn=<PjitFunction of <function relu at 0x77699a43ec00>>),
    Linear(
      weight=f32[10,300],
      bias=f32[10],
      in_features=300,
      out_features=10,
      use_bias=True
    )
  )
)]


## Train and test

A PC network can be updated in a single line of code with [`jpc.make_pc_step()`](https://thebuckleylab.github.io/jpc/api/Training/#jpc.make_pc_step). Similarly, we can use [`jpc.test_discriminative_pc()`](https://thebuckleylab.github.io/jpc/api/Testing/#jpc.test_discriminative_pc) to compute the network accuracy. Note that these functions are already "jitted" for optimised performance. Below we simply wrap each of these functions in training and test loops, respectively.

In [7]:
def record_logs(writer, iteration, result):
    """Write the scalars of one training step to a TensorBoard writer.

    Args:
        writer: object exposing `add_scalar(tag, value, step)`, such as a
            `SummaryWriter`.
        iteration: step index attached to every logged scalar.
        result: dict describing the training step. The keys "loss", "t_max",
            "model_param_grads", "energies", "activities" and "opt_state"
            are required. "acc" and "activity_norms" are optional and are
            skipped when missing or `None`.
    """
    norm = jax.numpy.linalg.norm

    writer.add_scalar('base/loss', float(result["loss"]), iteration)
    writer.add_scalar('base/t_max', float(result["t_max"]), iteration)
    acc = result.get("acc")
    if acc is not None:
        writer.add_scalar('base/acc', float(acc), iteration)

    # Gradient norms per layer; index 1 of each layer holds the module that
    # carries `weight` and `bias`.
    for layer_idx, param_grad in enumerate(result['model_param_grads']):
        linear_grad = param_grad[1]
        writer.add_scalar(
            f'grad/layer_{layer_idx}_weight_grad',
            float(norm(linear_grad.weight)),
            iteration,
        )
        writer.add_scalar(
            f'grad/layer_{layer_idx}_bias_grad',
            float(norm(linear_grad.bias)),
            iteration,
        )

    # Energy of each layer.
    for layer_idx, energy in enumerate(result["energies"]):
        writer.add_scalar(f'energies/layer_{layer_idx}', float(energy), iteration)

    # Norm of each layer's activity.
    for layer_idx, activity in enumerate(result["activities"]):
        writer.add_scalar(
            f'activations_raw/layer_{layer_idx}',
            float(norm(activity)),
            iteration,
        )

    # Norm of each layer's activity norms, when the step provides them.
    activity_norms = result.get("activity_norms")
    if activity_norms is not None:
        for layer_idx, activity in enumerate(activity_norms):
            writer.add_scalar(
                f'activations_norm/layer_{layer_idx}',
                float(norm(activity)),
                iteration,
            )

    # Optimiser momentum per layer, read from result["opt_state"][0].mu.
    # mu[0]: model; mu[1]: None (skip model)
    for layer_idx, momentum in enumerate(result['opt_state'][0].mu[0]):
        linear_momentum = momentum[1]
        writer.add_scalar(
            f'momentum/layer_{layer_idx}_weight_momentum',
            float(norm(linear_momentum.weight)),
            iteration,
        )
        writer.add_scalar(
            f'momentum/layer_{layer_idx}_bias_momentum',
            float(norm(linear_momentum.bias)),
            iteration,
        )


def evaluate(model, test_loader):
    """Average the test loss and accuracy of `model` over `test_loader`.

    Each batch is converted to NumPy and scored with
    `jpc.test_discriminative_pc`; the per-batch results are summed and
    divided by the number of batches.

    Returns:
        An `(average_loss, average_accuracy)` tuple.
    """
    total_loss = total_acc = 0
    for images, labels in test_loader:
        batch_loss, batch_acc = jpc.test_discriminative_pc(
            model=model,
            input=images.numpy(),
            output=labels.numpy()
        )
        total_loss += batch_loss
        total_acc += batch_acc

    n_batches = len(test_loader)
    return total_loss / n_batches, total_acc / n_batches


def train(
      model,
      lr,
      batch_size,
      test_every,
      n_train_iters,
      writer = None,
      log_every = 10,
):
    """Train `model` with predictive coding on MNIST, testing periodically.

    Args:
        model: the network to train, e.g. the list of layers returned by
            `jpc.make_mlp()`.
        lr: learning rate of the Adam optimiser.
        batch_size: batch size of both the train and the test loader.
        test_every: evaluate on the test set and print a progress line every
            this many training iterations.
        n_train_iters: stop after this many training iterations. Only a
            single pass over the training loader is made, so fewer
            iterations run if the loader is exhausted first.
        writer: optional TensorBoard writer; when given, step results are
            logged through `record_logs`.
        log_every: log every this many iterations when `writer` is given.
    """
    optim = optax.adam(lr)
    # The optimiser state covers a (model, skip model) pair; there is no
    # skip model here.
    opt_state = optim.init(
        (eqx.filter(model, eqx.is_array), None)
    )
    train_loader, test_loader = get_mnist_loaders(batch_size)

    for step, (img_batch, label_batch) in enumerate(train_loader):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        result = jpc.make_pc_step(
            model=model,
            optim=optim,
            opt_state=opt_state,
            output=label_batch,
            input=img_batch,
            calculate_accuracy=True,
            activity_norms=True
        )

        if writer is not None and (step % log_every) == 0:
            record_logs(writer, step, result)

        model, opt_state = result["model"], result["opt_state"]
        train_loss = result["loss"]

        n_done = step + 1
        if (n_done % test_every) == 0:
            _, avg_test_acc = evaluate(model, test_loader)
            print(
                f"Train iter {n_done}, train loss={train_loss:4f}, "
                f"avg test accuracy={avg_test_acc:4f}"
            )

        # Checked on every iteration, not only on test iterations, so that
        # training stops on time even when `n_train_iters` is not a multiple
        # of `test_every`.
        if n_done >= n_train_iters:
            break


## Run

In [8]:
def create_writer(log_dir):
    """Create a `SummaryWriter` in the next free `run_XXX` directory.

    Existing sub-directories of `log_dir` named `run_<digits>` are scanned
    and the new run gets the highest index plus one, starting at `run_000`
    when there are none (or when `log_dir` does not exist yet).

    Args:
        log_dir: parent directory of all runs, as a `str` or `Path`.

    Returns:
        A `SummaryWriter` logging to `log_dir / 'run_XXX'`.
    """
    log_dir = Path(log_dir)

    # Compare indices numerically and ignore anything that is not a run
    # directory, so stray files or folders cannot break the numbering.
    run_indices = []
    for path in log_dir.glob('run_*'):
        suffix = path.name.rpartition('_')[2]
        if path.is_dir() and suffix.isdigit():
            run_indices.append(int(suffix))

    new_index = max(run_indices, default=-1) + 1
    return SummaryWriter(log_dir=log_dir / f'run_{new_index:03d}')


# Log this run to a fresh numbered sub-directory of `runs/my_experiment`.
writer = create_writer(log_dir='runs/my_experiment')

# Train the network built above, logging step results through `writer`.
train(
    model=network,
    lr=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    test_every=TEST_EVERY,
    n_train_iters=N_TRAIN_ITERS,
    writer = writer
)

Train iter 100, train loss=0.009165, avg test accuracy=93.409454
Train iter 200, train loss=0.006444, avg test accuracy=95.402641
Train iter 300, train loss=0.005436, avg test accuracy=95.753204


In [9]:
# Collect the `.pkl` files in this run's `gradients` folder, sorted by path.
# NOTE(review): nothing visible in this notebook writes to that folder —
# presumably left over from a version of `train` that saved gradients; confirm.
grad_path = sorted((Path(writer.log_dir) / 'gradients').glob('*.pkl'))

In [15]:
# Load the first gradient snapshot and list the entries it contains.
# NOTE(review): raises IndexError when `grad_path` is empty.
with open(grad_path[0], 'rb') as f:
    gradient_data = pickle.load(f)

gradient_data.keys()

dict_keys(['layer_0_weight', 'layer_0_bias', 'layer_1_weight', 'layer_1_bias', 'layer_2_weight', 'layer_2_bias'])

In [None]:
            # for key, value in result.items():
            #     if key not in ["model", "skip_model", "opt_state", "activities"] and value is not None:
            #         writer.add_scalar(f'train/{key}', value, iter)

## Visualize Saved Gradients

After training, you can load and visualize the saved gradient data offline.

In [None]:
import matplotlib.pyplot as plt
import glob

def plot_gradients_from_file(grad_file, max_params=200):
    """Load a pickled gradient snapshot and bar-plot each layer's weight gradients.

    Args:
        grad_file: path (`str` or `Path`) of a pickle file whose name
            contains `_iter_<N>.pkl` and which holds a dict with one
            `layer_<i>_weight` entry per layer.
        max_params: maximum number of gradient values plotted per layer;
            larger gradients are subsampled at evenly spaced indices.

    Returns:
        The matplotlib figure, with one subplot per layer.

    Raises:
        ValueError: if no iteration number can be parsed from the file name
            or the file holds no weight gradients.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(grad_file, 'rb') as f:
        gradient_data = pickle.load(f)

    # Extract the iteration number from the filename; `os.fspath` lets the
    # string parsing work for `Path` objects as well as plain strings.
    iter_num = int(os.fspath(grad_file).split('_iter_')[-1].split('.pkl')[0])

    # One subplot per layer that has a weight gradient.
    n_layers = sum('weight' in name for name in gradient_data)
    if n_layers == 0:
        raise ValueError(f"no weight gradients found in {grad_file}")

    # `squeeze=False` always returns a 2-D array of axes, so a single-layer
    # snapshot needs no special case.
    fig, axes = plt.subplots(
        n_layers, 1, figsize=(12, 3*n_layers), squeeze=False
    )

    for layer_idx, ax in enumerate(axes[:, 0]):
        grad_flat = np.asarray(gradient_data[f'layer_{layer_idx}_weight']).ravel()

        # Subsample evenly when there are too many values to plot.
        if len(grad_flat) > max_params:
            indices = np.linspace(0, len(grad_flat)-1, max_params, dtype=int)
            grad_subset = grad_flat[indices]
        else:
            grad_subset = grad_flat

        ax.bar(range(len(grad_subset)), grad_subset, width=1.0)
        ax.set_xlabel('Parameter Index')
        ax.set_ylabel('Gradient Value')
        ax.set_title(f'Layer {layer_idx} Weight Gradients (Iteration {iter_num})')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

# Example: Plot gradients from a specific iteration
# grad_dir = 'runs/my_experiment/gradients'
# grad_files = sorted(glob.glob(f'{grad_dir}/gradients_iter_*.pkl'))
# if grad_files:
#     # Plot the first saved iteration
#     fig = plot_gradients_from_file(grad_files[0])
#     plt.show()
#     
#     # Or plot all iterations
#     for grad_file in grad_files:
#         fig = plot_gradients_from_file(grad_file)
#         plt.savefig(grad_file.replace('.pkl', '.png'))
#         plt.close()