
AaltoML/sfr


SFR - Sparse Function-space Representation of Neural Networks

This repository contains a clean, minimal PyTorch implementation of Sparse Function-space Representation (SFR) of Neural Networks. If you'd like to use SFR, we recommend using this repo. To reproduce the experiments in the ICLR 2024 paper, please see sfr-experiments.

Function-space Parameterization of Neural Networks for Sequential Learning
Aidan Scannell*, Riccardo Mereu*, Paul Chang, Ella Tamir, Joni Pajarinen, Arno Solin
International Conference on Learning Representations (ICLR 2024)
Paper Code Website
Sparse Function-space Representation of Neural Networks
Aidan Scannell*, Riccardo Mereu*, Paul Chang, Ella Tamir, Joni Pajarinen, Arno Solin
ICML 2023 Workshop on Duality Principles for Modern Machine Learning
Paper Code Website

Install

CPU

Create an environment with:

conda env create -f env_cpu.yaml

Activate the environment with:

source activate sfr

NVIDIA GPU

Create an environment with:

conda env create -f env_nvidia.yaml

Activate the environment with:

source activate sfr

Usage

See the notebooks for how to use our code for both regression and classification.

Image Classification

We provide a minimal training script, train.py, which trains a CNN and fits SFR on MNIST, Fashion-MNIST or CIFAR-10. We recommend running it on a GPU.
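For image classification the likelihood is categorical over the network's logits. Function-space methods of this kind typically need, for each training example, the gradient and the negative Hessian of log p(y | f) with respect to the logits f. The NumPy sketch below computes both; it assumes a softmax likelihood, the helper names are hypothetical, and it is not code taken from src or train.py.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def categorical_grad_hess(logits, label):
    """Gradient and Hessian of log p(y | f) w.r.t. the logits f for one example.

    Assumes a softmax likelihood: log p(y | f) = f[y] - logsumexp(f), so
      grad = onehot(y) - softmax(f)
      hess = -(diag(p) - p p^T), with p = softmax(f)
    """
    p = softmax(logits)
    onehot = np.zeros_like(p)
    onehot[label] = 1.0
    grad = onehot - p
    hess = -(np.diag(p) - np.outer(p, p))
    return grad, hess

# Hypothetical 3-class example
logits = np.array([2.0, 0.5, -1.0])
grad, hess = categorical_grad_hess(logits, label=0)
beta = -hess  # positive semi-definite per-example "precision" in function space
```

Note that beta does not depend on the label, only on the predicted class probabilities, so it can be computed from the network's outputs alone.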

Example

Here's a short example:

import src  # the SFR package in this repo; run from the repository root
import torch

torch.set_default_dtype(torch.float64)

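def func(x, noise=True):
    y = torch.sin(x * 5) / x + torch.cos(x * 10)
    if noise:
        # Add Gaussian observation noise (std 0.1 chosen for illustration)
        y = y + 0.1 * torch.randn_like(y)
    return y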

# Toy data set
X_train = torch.rand((100, 1)) * 2
Y_train = func(X_train, noise=True)
data = (X_train, Y_train)

# Training config
width = 64
num_epochs = 1000
batch_size = 16
learning_rate = 1e-3
delta = 0.00005  # prior precision
data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(*data), batch_size=batch_size, shuffle=True
)

# Create a neural network
network = torch.nn.Sequential(
    torch.nn.Linear(1, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, 1),
)

# Instantiate SFR (handles NN training/prediction as they're coupled via the prior/likelihood)
sfr = src.SFR(
    network=network,
    prior=src.priors.Gaussian(params=network.parameters, delta=delta),
    likelihood=src.likelihoods.Gaussian(sigma_noise=2),
    output_dim=1,
    num_inducing=32,
    dual_batch_size=None,  # set to an int to compute the dual parameters in batches and reduce memory use
    jitter=1e-4,
)

sfr.train()
optimizer = torch.optim.Adam([{"params": sfr.parameters()}], lr=learning_rate)
for epoch_idx in range(num_epochs):
    for batch_idx, batch in enumerate(data_loader):
        x, y = batch
        loss = sfr.loss(x, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

sfr.set_data(data) # This builds the dual parameters

# Make predictions in function space
X_test = torch.linspace(-0.7, 3.5, 300, dtype=torch.float64).reshape(-1, 1)
f_mean, f_var = sfr.predict_f(X_test)

# Make predictions in output space
y_mean, y_var = sfr.predict(X_test)
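
The comment on src.SFR says that training is coupled to the prior and likelihood. A common reading of such a setup is MAP training: minimise the negative log-likelihood plus a Gaussian weight penalty, where delta is the prior precision, i.e. p(theta) = N(0, delta^-1 I). The NumPy sketch below computes that objective for a Gaussian likelihood. How sfr.loss actually scales each term (per batch or over the full data set) is an assumption here, and the function name is hypothetical.

```python
import numpy as np

def gaussian_map_objective(y, f, params, delta, sigma_noise):
    """-log p(y | f) - log p(theta), dropping the prior's normalising constant.

    y, f:        targets and network outputs (same shape)
    params:      flat vector of network weights theta
    delta:       prior precision, p(theta) = N(0, delta^-1 I)
    sigma_noise: standard deviation of the Gaussian likelihood
    """
    n = y.size
    resid = y - f
    # Sum over examples of -log N(y_i | f_i, sigma^2)
    nll = (0.5 * np.sum(resid**2) / sigma_noise**2
           + n * np.log(sigma_noise)
           + 0.5 * n * np.log(2 * np.pi))
    # Gaussian prior acts as L2 weight decay with coefficient delta / 2
    neg_log_prior = 0.5 * delta * np.sum(params**2)
    return nll + neg_log_prior

# Example using the README's settings (delta=0.00005, sigma_noise=2) on made-up values
objective = gaussian_map_objective(
    np.array([0.3, -0.1]), np.array([0.25, 0.0]), np.zeros(10),
    delta=0.00005, sigma_noise=2.0,
)
```

With a very small delta, as in the example above, the prior term is a weak regulariser and the likelihood term dominates.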

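In the example, sfr.set_data builds the dual parameters and sfr.predict_f returns a function-space mean and variance using num_inducing inducing points. As a rough illustration of those ingredients, not the repo's implementation, the NumPy sketch below (i) forms the linearised-network kernel kappa(x, x') = J(x) J(x')^T / delta from finite-difference Jacobians of a scalar model and (ii) evaluates the standard DTC sparse-GP predictive for a Gaussian likelihood, assuming per-point dual parameters alpha_i = y_i / sigma^2 and beta_i = 1 / sigma^2 and a zero prior mean. All function names, the finite-difference Jacobian and the zero prior mean are assumptions made for this sketch.

```python
import numpy as np

def jacobian_fd(model, theta, X, eps=1e-6):
    """Central finite-difference Jacobian of a scalar model f(X; theta), shape (len(X), len(theta))."""
    J = np.zeros((len(X), len(theta)))
    for k in range(len(theta)):
        step = np.zeros_like(theta)
        step[k] = eps
        J[:, k] = (model(X, theta + step) - model(X, theta - step)) / (2 * eps)
    return J

def jacobian_kernel(model, theta, delta):
    """Return kappa(A, B) = J(A) J(B)^T / delta, the prior kernel of the linearised model."""
    def kernel(A, B):
        return jacobian_fd(model, theta, A) @ jacobian_fd(model, theta, B).T / delta
    return kernel

def sparse_predict(kernel, X, y, Z, X_test, sigma_noise, jitter=1e-6):
    """DTC sparse-GP predictive mean and variance at X_test with inducing inputs Z."""
    beta_hat = 1.0 / sigma_noise**2            # per-point dual "precision" (Gaussian likelihood)
    alpha_hat = y / sigma_noise**2             # per-point dual "mean" (zero prior mean assumed)
    Kzz = kernel(Z, Z) + jitter * np.eye(len(Z))
    Kzx = kernel(Z, X)                         # (M, N)
    alpha_u = Kzx @ alpha_hat                  # sum_i k_zi alpha_i
    B_u = beta_hat * Kzx @ Kzx.T               # sum_i k_zi beta_i k_zi^T, costs O(N M^2)
    M_u = Kzz + B_u
    Kzs = kernel(Z, X_test)                    # (M, T)
    mean = Kzs.T @ np.linalg.solve(M_u, alpha_u)
    kss = np.diag(kernel(X_test, X_test))      # prior variance at the test inputs
    # var = k_** - k_z*^T (Kzz^-1 - M_u^-1) k_z*, evaluated column by column
    var = kss - np.sum(Kzs * (np.linalg.solve(Kzz, Kzs) - np.linalg.solve(M_u, Kzs)), axis=0)
    return mean, var
```

Only the M x M matrices Kzz and M_u are ever solved against, which is why a small number of inducing points keeps prediction cheap. When Z equals the training inputs, the predictive reduces to the exact GP posterior. In practice an automatic-differentiation Jacobian (for example torch.func.jacrev) would replace the finite differences used here for self-containment.
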
Development

Set up pre-commit by running:

pre-commit install

Now the formatter, linter and other hooks will run automatically on each commit.

Citation

Please consider citing our conference paper:

@inproceedings{scannellFunction2024,
  title           = {Function-space Parameterization of Neural Networks for Sequential Learning},
  booktitle       = {Proceedings of The Twelfth International Conference on Learning Representations (ICLR 2024)},
  author          = {Aidan Scannell and Riccardo Mereu and Paul Chang and Ella Tamir and Joni Pajarinen and Arno Solin},
  year            = {2024},
  month           = {5},
}

or our workshop paper:

@inproceedings{scannellSparse2023,
  title           = {Sparse Function-space Representation of Neural Networks},
  booktitle       = {ICML 2023 Workshop on Duality Principles for Modern Machine Learning},
  author          = {Aidan Scannell and Riccardo Mereu and Paul Chang and Ella Tamir and Joni Pajarinen and Arno Solin},
  year            = {2023},
  month           = {7},
}