# Multi-output federated SGPs

Performance on the Iris classification dataset.

In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
import logging
import pickle
import copy

# Make the local `pvi` checkout importable from this notebook.
# NOTE: this is a machine-specific absolute path.
module_path = os.path.abspath("/Users/matt/projects/pvi")
if module_path not in sys.path:
    # Only register the path once, even if the cell is executed repeatedly.
    sys.path.append(module_path)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import tqdm.auto as tqdm

# Create floating-point tensors as float64 by default from here on
# (presumably for numerical stability of the GP computations - TODO confirm).
torch.set_default_dtype(torch.float64)

In [None]:
from pvi.models import MOSparseGaussianProcessClassification
from pvi.distributions import MultivariateGaussianDistributionWithZ, MultivariateGaussianFactorWithZ
from pvi.clients.federated_sgp import FederatedSGPClient
from pvi.servers import GlobalVIServer
from pvi.servers.federated_sgp import SequentialSGPServer, SynchronousSGPServer
from pvi.models.sgp.kernels import RBFKernel
from pvi.utils.training_utils import EarlyStopping
from data.split_data import homogenous

# Load data and set up helper functions

In [None]:
from sklearn import datasets

# Fetch the Iris dataset: feature matrix and class labels.
iris = datasets.load_iris()
x, y = iris.data, iris.target

# Normalise data: centre each feature column and scale it by its standard
# deviation.
x = (x - x.mean(axis=0)) / x.std(axis=0)

# A single random shuffle of the row indices decides the split: the first
# `ntrain` shuffled rows are used for training, the remainder is held out.
# NOTE: no seed is set in this cell, so the split differs between runs.
ntrain = 100
perm = np.random.permutation(len(y))

train_data = {
    name: torch.tensor(values[perm[:ntrain]])
    for name, values in (("x", x), ("y", y))
}

test_data = {
    name: torch.tensor(values[perm[ntrain:]])
    for name, values in (("x", x), ("y", y))
}

In [None]:
def performance_metrics(client, data):
    """Compute classification accuracy and mean log-likelihood on a dataset.

    Args:
        client: Any object exposing ``model_predict(x)``. The returned
            predictive must provide ``component_distribution.probs`` and
            ``log_prob``.
        data: Mapping with the inputs under ``"x"`` and the class labels
            under ``"y"``.

    Returns:
        A dict of Python floats: ``"acc"`` is the fraction of points whose
        highest-probability class equals the label, and ``"mll"`` is the mean
        of ``log_prob(y)`` over the points.
    """
    x, y = data["x"], data["y"]

    pp = client.model_predict(x)
    # Average the class probabilities over dimension 1 of the component
    # probabilities (presumably the predictive-sample dimension - TODO confirm).
    preds = pp.component_distribution.probs.mean(1)

    mll = pp.log_prob(y).mean()

    # Count the correct predictions with a tensor reduction; the built-in
    # `sum` would walk the tensor one element at a time in Python.
    correct = torch.argmax(preds, dim=-1) == y
    acc = correct.sum() / len(y)

    metrics = {
        "acc": acc.item(),
        "mll": mll.item(),
    }

    return metrics

class Client:
    """Minimal client that only holds the dataset it is constructed with."""

    def __init__(self, data):
        # Keep a reference to the caller's data; nothing is copied or checked.
        self.data = data

# Model, client and server configurations

In [None]:
# Number of inducing points for the global-VI baseline.
num_inducing = 10
# Input dimensionality: the number of feature columns in x.
D = x.shape[1]
# Number of model outputs (presumably one per Iris class - TODO confirm).
P = 3

# Shared across all clients.
model_config = {
    "D": D,
    "P": P,
    "share_kernel": False,
    "num_inducing": num_inducing,
    # Factory that forwards its keyword arguments to RBFKernel.
    "kernel_class": lambda **kwargs: RBFKernel(**kwargs),
    "kernel_params": {
        "ard_num_dims": D, 
        "train_hypers": False,    # Don't train kernel hyperparameters for now.
    },
    "num_predictive_samples": 100
}

# Settings handed to the server. The exact meaning of each key is defined by
# the pvi server classes, which are not shown in this notebook.
server_config = {
    "max_iterations": 1,
    "optimiser": "Adam",
    "optimiser_params": {"lr": 1e-3},
    "epochs": 2000,
    "num_elbo_samples": 100,
    "print_epochs": 10,
    # The metrics helper defined earlier in this notebook.
    "performance_metrics": performance_metrics,
    # NOTE(review): this is a single EarlyStopping instance; every consumer of
    # this dict gets the same object - confirm it is reset between runs.
    "early_stopping": EarlyStopping(25),
    "train_model": False,
}

In [None]:
# Construct client.
client = Client(train_data)

# Construct server.
model = MOSparseGaussianProcessClassification(config=model_config)

# Randomly initialise global inducing points at a random subset of the
# training inputs.
perm = torch.randperm(len(train_data["x"]))
# train_data["x"] is already a tensor, so take the copy with clone().detach();
# wrapping an existing tensor in torch.tensor() emits a copy-construct
# UserWarning.
z = train_data["x"][perm[:num_inducing]].clone().detach()
kzz = model.kernel(z, z)

# Prior over the inducing outputs: a zero mean of shape (P, number of inducing
# points) and the kernel matrix at z as covariance.
p = MultivariateGaussianDistributionWithZ(
    std_params={
        "loc": torch.zeros(P, z.shape[0]),
        "covariance_matrix": kzz,
    },
    inducing_locations=z,
    train_inducing=True,
    is_trainable=False,
)

server = GlobalVIServer(model=model, p=p, clients=[client], config=server_config, val_data=test_data)

# Obtain prior predictions.
train_metrics = performance_metrics(server, train_data)
test_metrics = performance_metrics(server, test_data)
print("Test mll: {:.3f}. Test acc: {:.3f}.".format(test_metrics["mll"], test_metrics["acc"]))
print("Train mll: {:.3f}. Train acc: {:.3f}.\n".format(train_metrics["mll"], train_metrics["acc"]))

# Run global VI (this cell uses GlobalVIServer with a single client).
while not server.should_stop():
    server.tick()

    # Obtain predictions.
    train_metrics = performance_metrics(server, train_data)
    test_metrics = performance_metrics(server, test_data)
    print("Test mll: {:.3f}. Test acc: {:.3f}.".format(test_metrics["mll"], test_metrics["acc"]))
    print("Train mll: {:.3f}. Train acc: {:.3f}.\n".format(train_metrics["mll"], train_metrics["acc"]))

# Now try PVI

In [None]:
# Number of inducing points per client for the PVI experiment.
num_inducing = 5

# Settings handed to every FederatedSGPClient. The exact meaning of each key
# is defined by the pvi client class, which is not shown in this notebook.
client_config = {
    "optimiser": "Adam",
    "optimiser_params": {"lr": 1e-3},
    "epochs": 2000,
    "batch_size": 50,
    "num_elbo_samples": 10,
    "num_elbo_hyper_samples": 2,
    "valid_factors": True,
    # NOTE(review): one EarlyStopping instance is shared by every client that
    # receives this dict - confirm the clients reset or copy it.
    "early_stopping": EarlyStopping(25),
    "damping_factor": 1.
}

# Shared across all clients.
model_config = {
    "D": D,
    "P": P,
    "share_kernel": False,
    "num_inducing": num_inducing,
    # Factory that forwards its keyword arguments to RBFKernel.
    "kernel_class": lambda **kwargs: RBFKernel(**kwargs),
    "kernel_params": {
        "ard_num_dims": D, 
        "train_hypers": False
    },
    "num_predictive_samples": 100
}

# Reuses the server_config dict defined for the global-VI run and only raises
# the iteration limit; all other entries (including its EarlyStopping
# instance) carry over unchanged.
server_config["max_iterations"] = 10

# Initial natural parameters for each client's factor: an all-zero vector of
# length num_inducing and an all-zero (num_inducing x num_inducing) matrix.
# NOTE(review): these have no leading output dimension of size P - confirm
# the factor class broadcasts them across outputs.
init_nat_params = {
    "np1": torch.zeros(model_config["num_inducing"]),
    "np2": torch.zeros(model_config["num_inducing"]).diag_embed(),
}

In [None]:
# Number of clients to construct.
M = 2
# Split the training set into per-client datasets with a fixed seed
# (presumably M roughly equal, identically distributed shards - TODO confirm
# against data.split_data.homogenous).
clients_data = homogenous(train_data["x"], train_data["y"], m=M, dataset_seed=42)

In [None]:
# Construct clients.
clients = []
z_is = []

for i in range(M):
    model_i = MOSparseGaussianProcessClassification(config=model_config)
    data_i = clients_data[i]

    # Convert to torch.tensor. The conversion is written back into
    # clients_data, so on a re-run of this cell the values are already
    # tensors; copy those with clone().detach() because torch.tensor() on a
    # tensor emits a copy-construct UserWarning.
    for k, v in data_i.items():
        data_i[k] = v.clone().detach() if torch.is_tensor(v) else torch.tensor(v)

    # Randomly initialise private inducing points at a random subset of this
    # client's inputs. Indexing with an index tensor already returns a new
    # tensor, so no further copy is needed.
    perm = torch.randperm(len(data_i["x"]))
    z_i = data_i["x"][perm[:num_inducing]]
    z_is.append(z_i)

    # NOTE(review): every client's factor receives the same init_nat_params
    # dict - confirm the factor class copies the tensors it is given.
    t = MultivariateGaussianFactorWithZ(
        nat_params=init_nat_params,
        inducing_locations=z_i,
        train_inducing=True,
    )

    clients.append(FederatedSGPClient(data=data_i, model=model_i, t=t, config=client_config))

# Construct global model and server.
model = MOSparseGaussianProcessClassification(config=model_config)
# No explicit prior is passed to the server.
p = None

server = SequentialSGPServer(
    model=model,
    p=p,
    clients=clients,
    config=server_config,
    maintain_inducing=False,
)

# Run PVI!
while not server.should_stop():
    server.tick()

    # Obtain predictions.
    train_metrics = performance_metrics(server, train_data)
    test_metrics = performance_metrics(server, test_data)
    print("Test mll: {:.3f}. Test acc: {:.3f}.".format(test_metrics["mll"], test_metrics["acc"]))
    print("Train mll: {:.3f}. Train acc: {:.3f}.\n".format(train_metrics["mll"], train_metrics["acc"]))