In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F
import random
from collections import OrderedDict
import os
import numpy as np

import predictive_coding as pc

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'using {device}')

using cuda


In [2]:
# This class holds the learnable prior mean \mu (see figure)
class BiasLayer(nn.Module):
    """Layer whose output is a learnable bias vector, whatever values the input holds.

    The input is only used as a template (via ``torch.zeros_like``) for the
    zero tensor that the bias is added to.

    Args:
        num_features: length of the bias vector.
        offset: initial value given to every entry of the bias.
    """

    def __init__(self, num_features, offset=0.):
        super().__init__()
        initial_bias = offset * torch.ones(num_features)
        self.bias = nn.Parameter(initial_bias)

    def forward(self, x):
        # The values of x are discarded: start from zeros shaped like x and
        # let the bias broadcast onto them.
        blank = torch.zeros_like(x)
        return blank + self.bias

# function to add noise to the inference dynamics of the PC layers
def random_step(t, _trainer, var=2.):
    """Replace the gradients of the trainer's x's with Gaussian noise and take one optimizer step.

    Args:
        t: not used in the body.
        _trainer: object providing ``get_model_xs()`` and ``get_optimizer_x()``.
        var: sets the variance of the noise. Each gradient is overwritten
            in place with zero-mean normal samples of standard deviation
            ``sqrt(var / lr)``, where ``lr`` is the optimizer's default
            learning rate.
    """
    states = _trainer.get_model_xs()
    x_optimizer = _trainer.get_optimizer_x()
    # NOTE(review): dividing by lr looks designed for a plain SGD update
    # (step = lr * grad); confirm this is intended when the x optimizer is
    # something else (e.g. Adam).
    noise_std = np.sqrt(var / x_optimizer.defaults['lr'])
    for state in states:
        state.grad.normal_(0., noise_std)
    x_optimizer.step()

def loss_fn(output, _target):
    """Return half of the summed squared difference between `output` and `_target`."""
    residual = output - _target
    return 0.5 * torch.sum(residual ** 2)

In [3]:
# test function
def test(model, loader):
    """Measure classification accuracy of `model` by inferring the label layer.

    `model` is wrapped as ``BiasLayer -> PCLayer -> model``. For every batch
    the image is handed to the loss as the target, inference is run on the
    wrapped network without updating parameters, and the arg-max of the first
    PCLayer's x is compared against the true label.

    Args:
        model: the network mapping a label-sized vector to an image-sized
            vector; it is used as the last stage of the wrapped network.
        loader: iterable yielding ``(data, target)`` batches where `target`
            holds integer class labels.

    Returns:
        Fraction of the samples yielded by `loader` whose inferred label
        matches the true label.
    """
    # One width for both the BiasLayer and the one-hot encoding, so the two
    # cannot drift apart.
    num_classes = options.layer_sizes[0]

    test_model = nn.Sequential(
        BiasLayer(num_classes, offset=0.1),
        pc.PCLayer(),
        model
    ).to(device)

    test_model.train()

    test_trainer = pc.PCTrainer(
        test_model,
        T=200,
        update_x_at='all',
        optimizer_x_fn=optim.SGD,
        optimizer_x_kwargs={"lr": 0.1},
        update_p_at='never', # do not update parameters during inference
        plot_progress_at=[],
        x_lr_discount=0.5,
    )

    correct_cnt = 0
    sample_cnt = 0
    for data, target in loader:
        data = data.to(device)
        target = F.one_hot(target, num_classes=num_classes).to(torch.float32).to(device)
        # BiasLayer discards the values of its input, so passing the one-hot
        # label as `inputs` only fixes the shape of the label layer; the
        # label itself is not visible to the network.
        test_trainer.train_on_batch(
            inputs=target,
            loss_fn=loss_fn,
            loss_fn_kwargs={'_target': data},
            is_log_progress=False,
            is_return_results_every_t=False,
        )
        pred = test_model[1].get_x()
        correct_cnt += pred.argmax(dim=1).eq(target.argmax(dim=1)).sum().item()
        # Count what was actually evaluated, so the result stays a true
        # accuracy even if the loader does not visit the whole dataset.
        sample_cnt += target.shape[0]
    return correct_cnt / sample_cnt

In [4]:
class Options:
    """Empty namespace class; settings are attached to an instance as attributes."""


options = Options()

# All experiment settings, written into the instance namespace in one go.
vars(options).update(
    batch_size=500,
    train_size=10000,
    test_size=1000,
    layer_sizes=[10, 256, 256, 784],
    activation=nn.Tanh(),
)

def get_mnist(options):
    """Build MNIST data loaders for training, validation and testing.

    Every image is converted to a tensor and flattened. The official train
    and test sets are randomly subsampled to `options.train_size` and
    `options.test_size` when those differ from the full set sizes; the
    (possibly subsampled) train set is then split 90% / 10% into training
    and validation parts.

    Args:
        options: object with `train_size`, `test_size` and `batch_size`
            attributes.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles.
    """
    # torch.flatten is passed directly; wrapping it in a lambda added nothing.
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Lambda(torch.flatten)])
    # Local names avoid `test`, which is also the name of the evaluation
    # function defined at module level.
    train_full = datasets.MNIST('./data', train=True, transform=transform, download=True)
    test_set = datasets.MNIST('./data', train=False, transform=transform, download=True)

    # Randomly subsample when a size other than the full set is requested.
    if options.train_size != len(train_full):
        train_full = torch.utils.data.Subset(
            train_full, random.sample(range(len(train_full)), options.train_size))
    if options.test_size != len(test_set):
        test_set = torch.utils.data.Subset(
            test_set, random.sample(range(len(test_set)), options.test_size))

    # Split the training set into training and validation sets
    train_fraction = 0.9
    train_len = int(len(train_full) * train_fraction)  # 90% of data for training
    val_len = len(train_full) - train_len              # remaining 10% for validation
    train_set, val_set = torch.utils.data.random_split(train_full, [train_len, val_len])

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=options.batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=options.batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=options.batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

In [5]:
# Every trainer below runs T_mixing + T_sampling inference steps per batch.
T_mixing = 100
T_sampling = 100

# Three linear maps; a PCLayer and the activation follow each of the first two.
model = nn.Sequential(
    nn.Linear(options.layer_sizes[0], options.layer_sizes[1]),  # generates the prediction
    pc.PCLayer(),
    options.activation,
    nn.Linear(options.layer_sizes[1], options.layer_sizes[2]),
    pc.PCLayer(),
    options.activation,
    nn.Linear(options.layer_sizes[2], options.layer_sizes[3]),
).to(device)

# build the train / validation / test data loaders
train_loader, val_loader, test_loader = get_mnist(options)

model.train()

# warm-up inference for mcpc: x is updated at every step, parameters never
inference_trainer = pc.PCTrainer(
    model,
    T=T_mixing + T_sampling,
    update_x_at="all",
    optimizer_x_fn=optim.Adam,
    optimizer_x_kwargs={"lr": 0.1},
    update_p_at="never",
    plot_progress_at=[],
)

# training the model: same x settings, plus a parameter optimizer
train_trainer = pc.PCTrainer(
    model,
    T=T_mixing + T_sampling,
    update_x_at="all",
    optimizer_x_fn=optim.Adam,
    optimizer_x_kwargs={"lr": 0.1},
    update_p_at="last",
    # steps T_mixing, ..., T_mixing + T_sampling - 1, i.e. everything after
    # the first T_mixing steps
    accumulate_p_at=list(range(T_mixing, T_mixing + T_sampling)),
    optimizer_p_fn=optim.Adam,
    optimizer_p_kwargs={"lr": 0.001, "weight_decay": 0.001},
    plot_progress_at=[],
)

In [6]:
# Number of passes over the training loader.
num_epochs = 10
# Width of the one-hot labels; must equal the input width of the model's
# first linear layer, which is also taken from options.layer_sizes[0].
num_classes = options.layer_sizes[0]

for i in range(num_epochs):
    print(f'epoch {i}')
    for data, target in train_loader:
        data = data.to(device)
        target = F.one_hot(target, num_classes=num_classes).to(torch.float32).to(device)

        # initialise sampling
        pc_results = inference_trainer.train_on_batch(
            inputs=target,
            loss_fn=loss_fn,
            loss_fn_kwargs={'_target': data},
            is_log_progress=False,
            is_return_results_every_t=False,
            is_checking_after_callback_after_t=False
        )
        # mc inference: random_step is called back after each step, and x is
        # not re-sampled at batch start, so this run continues from the state
        # left by the warm-up above
        mc_results = train_trainer.train_on_batch(
            inputs=target,
            loss_fn=loss_fn,
            loss_fn_kwargs={'_target': data},
            callback_after_t=random_step,
            callback_after_t_kwargs={'_trainer': train_trainer},
            is_sample_x_at_batch_start=False,
            is_log_progress=False,
            is_return_results_every_t=False,
            is_checking_after_callback_after_t=False
        )

    # Each call below runs a full inference pass over the given loader.
    print(f'training accuracy: {test(model, train_loader)}')
    print(f'validation accuracy: {test(model, val_loader)}')

print(f'test accuracy: {test(model, test_loader)}')

epoch 0




training accuracy: 0.2871111111111111
validation accuracy: 0.252
epoch 1
training accuracy: 0.45111111111111113
validation accuracy: 0.436
epoch 2


KeyboardInterrupt: 