# Federated SGP demo on banana dataset with private inducing points

In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
import logging

# Make the directory two levels above the working directory importable,
# so that the `pvi` package used by the later cells can be found.
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir))
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
from pvi.models import SparseGaussianProcessClassification
from pvi.distributions import MultivariateGaussianDistributionWithZ, MultivariateGaussianFactorWithZ
from pvi.distributions import LogNormalDistribution, LogNormalFactor
from pvi.distributions.hypers import HyperparameterDistribution, HyperparameterFactor
from pvi.clients.federated_sgp import FederatedSGPClient
from pvi.servers.federated_sgp import SequentialSGPServer, SynchronousSGPServer
from pvi.utils.training_utils import EarlyStopping

import torch
import numpy as np
import matplotlib.pyplot as plt
import tqdm.auto as tqdm
import gpytorch

from torch import nn
from pvi.models.sgp.kernels import RBFKernel


%matplotlib inline
torch.set_default_dtype(torch.float64)

### Set up data and helper functions

In [None]:
def load_data(name, train_proportion, data_base_dir):
    """Read the banana CSV and split it into a leading train / trailing test set.

    The file ``<data_base_dir>/<name>/banana.csv`` is parsed as comma-separated
    numbers with its first row skipped. The first two columns are used as
    inputs and the last column as the label; labels equal to 1 become 0 and
    labels equal to 2 become 1.

    Args:
        name: Sub-directory of ``data_base_dir`` that contains ``banana.csv``.
        train_proportion: Fraction of rows (rounded up) assigned to training.
            Rows are taken in file order, without shuffling.
        data_base_dir: Directory holding the dataset sub-directories.

    Returns:
        ``(training_set, test_set, D)`` where both sets are dicts with keys
        ``"x"`` and ``"y"`` and ``D`` is the number of input columns.
    """
    csv_path = os.path.join(data_base_dir, name, "banana.csv")
    table = np.loadtxt(csv_path, delimiter=",", skiprows=1)

    features, labels = table[:, :2], table[:, -1]

    # Order matters: map 1 -> 0 first so that the 1's produced from the 2's
    # in the next line are not themselves remapped.
    labels[labels == 1] = 0
    labels[labels == 2] = 1

    n_train = int(np.ceil(train_proportion * features.shape[0]))

    training_set = dict(x=features[:n_train], y=labels[:n_train])
    test_set = dict(x=features[n_train:], y=labels[n_train:])

    return training_set, test_set, features.shape[1]

def generate_clients_data(x, y, M, dataset_seed):
    """Split ``(x, y)`` into ``M`` equally sized client datasets ordered by x1.

    For ``M == 1`` the data is returned unchanged as a single client.
    Otherwise the points are sorted by their first input coordinate and cut
    into ``M`` consecutive chunks of ``floor(N / M)`` points each, so every
    client sees a different slice of the x1 axis. Any ``N % M`` points left
    over after the last chunk are not assigned to a client.

    The global NumPy RNG is seeded with ``dataset_seed`` (when not ``None``)
    for the duration of the call and is always restored to its previous state
    before returning, including on the ``M == 1`` path and when an exception
    is raised. Nothing in this function currently draws random numbers.

    Args:
        x: Array of inputs, indexed as ``x[:, 0]`` for the first coordinate.
        y: Array of labels; every label must be either ``0`` or ``> 0``.
        M: Number of clients.
        dataset_seed: Seed for the global NumPy RNG, or ``None`` to leave it.

    Returns:
        ``(client_data, N_is, props_positive, M)``: a list of
        ``{"x": ..., "y": ...}`` dicts, the number of points per client, the
        fraction of positive labels per client, and ``M`` itself.

    Raises:
        AssertionError: If some label is neither ``0`` nor positive.
    """
    # Snapshot the caller's RNG state so seeding below never leaks out.
    random_state = np.random.get_state()

    try:
        if dataset_seed is not None:
            np.random.seed(dataset_seed)

        if M == 1:
            client_data = [{"x": x, "y": y}]
            N_is = [x.shape[0]]
            props_positive = [np.mean(y > 0)]

            return client_data, N_is, props_positive, M

        client_size = x.shape[0] // M

        pos_inds = np.where(y > 0)
        zero_inds = np.where(y == 0)

        assert (len(pos_inds[0]) + len(zero_inds[0])) == len(y), "Some indeces missed."

        print(f'x shape {x.shape}')

        # Group positives before negatives. This only influences how ties in
        # x1 are ordered by the sort below; it is kept so that the resulting
        # split is unchanged.
        x = np.concatenate([x[pos_inds], x[zero_inds]])
        y = np.concatenate([y[pos_inds], y[zero_inds]])

        # As in Bui et al, order according to x1 value.
        inds = np.argsort(x[:, 0])
        x = x[inds]
        y = y[inds]

        # Distribute consecutive chunks among clients.
        client_data = [
            {
                'x': x[i * client_size:(i + 1) * client_size],
                'y': y[i * client_size:(i + 1) * client_size],
            }
            for i in range(M)
        ]

        N_is = [data['x'].shape[0] for data in client_data]
        props_positive = [np.mean(data['y'] > 0) for data in client_data]

        return client_data, N_is, props_positive, M
    finally:
        np.random.set_state(random_state)
    
def plot_data(x, y, ax=None):
    """Scatter the two classes and return the grid covering the plot area.

    Points with ``y == 0`` are drawn in colour ``C1`` ("Class 1") and points
    with ``y == 1`` in colour ``C0`` ("Class 2"). Axis limits are fixed to
    ``[-3, 3]`` in both directions.

    Args:
        x: Inputs, indexed as ``x[mask, 0]`` / ``x[mask, 1]``.
        y: Labels compared against ``0`` and ``1``.
        ax: Axes to draw on. When ``None`` a new 8x6 inch, 200 dpi figure is
            created and its current axes are used.

    Returns:
        ``(x1x1, x2x2)``: the 100x100 meshgrid spanning the plotted region.
    """
    lower, upper = -3., 3.
    x1_ticks = np.linspace(lower, upper, 100)
    x2_ticks = np.linspace(lower, upper, 100)
    x1x1, x2x2 = np.meshgrid(x1_ticks, x2_ticks)

    if ax is None:
        plt.figure(figsize=(8, 6), dpi=200)
        ax = plt.gca()

    # (label value, colour, legend entry) for each class, in drawing order.
    class_styles = ((0, "C1", "Class 1"), (1, "C0", "Class 2"))
    for label_value, colour, legend_name in class_styles:
        members = y == label_value
        ax.plot(x[members, 0], x[members, 1], "o",
                color=colour, label=legend_name, alpha=.5, zorder=1)

    ax.set_xlabel("x1")
    ax.set_ylabel("x2")
    ax.set_xlim(x1x1.min(), x1x1.max())
    ax.set_ylim(x2x2.min(), x2x2.max())
    ax.legend(loc="upper left", scatterpoints=1, numpoints=1)

    return x1x1, x2x2

def acc_and_ll(pred_probs, x, y):
    """Return classification accuracy and mean Bernoulli log-likelihood.

    Args:
        pred_probs: NumPy array of predicted probabilities of the label being 1.
        x: Unused; accepted so callers can pass the inputs alongside the labels.
        y: NumPy array of labels (0 or 1) aligned with ``pred_probs``.

    Returns:
        ``(acc, loglik)``: the fraction of points where thresholding
        ``pred_probs`` at 0.5 matches ``y``, and the mean log-probability of
        ``y`` under ``Bernoulli(pred_probs)`` as a NumPy value.
    """
    hard_predictions = pred_probs > 0.5
    acc = np.mean(hard_predictions == y)

    # Clip into [0, 1] so the values are valid Bernoulli probabilities.
    clipped = torch.clip(torch.tensor(pred_probs), 0., 1.)
    likelihood = torch.distributions.Bernoulli(probs=clipped)
    loglik = likelihood.log_prob(torch.tensor(y)).mean().numpy()

    return acc, loglik

def plot_predictive_distribution(x, y, z, model=None, q=None, ax=None, x_old=None, y_old=None, z_old=None):
    """Plot data, inducing points and (optionally) predictive contours.

    Args:
        x, y: Inputs and labels passed to ``plot_data``.
        z: Inducing point locations, indexed as ``z[:, 0]`` / ``z[:, 1]``.
            May be a NumPy array or a torch tensor.
        model: Callable as ``model(inputs, q=q, diag=True)`` returning an
            object with a ``.mean`` tensor. Contours are drawn only when both
            ``model`` and ``q`` are given.
        q: Approximate posterior forwarded to ``model``.
        ax: Axes to draw on; when ``None`` the figure created by ``plot_data``
            is used.
        x_old, y_old: Optional earlier data, drawn more transparently.
        z_old: Optional earlier inducing points, drawn only together with
            ``x_old`` and ``y_old``.
    """
    def _as_plottable(points):
        # Tensors that require grad cannot be converted to NumPy implicitly,
        # so detach them before handing them to matplotlib.
        if isinstance(points, torch.Tensor):
            return points.detach().cpu().numpy()
        return points

    x1x1, x2x2 = plot_data(x, y, ax)

    # Resolve the axes BEFORE first use: with ax=None, plot_data has just
    # created a figure whose current axes we must draw on.
    if ax is None:
        ax = plt.gca()

    z = _as_plottable(z)
    ax.scatter(z[:, 0], z[:, 1], color="r", marker="o", zorder=3)

    if model is not None and q is not None:
        x_predict = np.concatenate((x1x1.ravel().reshape(-1, 1),
                                    x2x2.ravel().reshape(-1, 1)), 1)

        with torch.no_grad():
            y_predict = model(torch.tensor(x_predict), q=q, diag=True)
            y_predict = y_predict.mean.numpy().reshape(x1x1.shape)

        cs2 = ax.contour(x1x1, x2x2, y_predict, colors=["k"], levels=[0.2, 0.5, 0.8], zorder=2)
        ax.clabel(cs2, fmt="%2.1f", colors="k", fontsize=14)

    if x_old is not None and y_old is not None:
        ax.plot(x_old[y_old == 0, 0], x_old[y_old == 0, 1], "o", color="C1", label="Class 1", alpha=.25)
        ax.plot(x_old[y_old == 1, 0], x_old[y_old == 1, 1], "o", color="C0", label="Class 2", alpha=.25)

        if z_old is not None:
            z_old = _as_plottable(z_old)
            ax.scatter(z_old[:, 0], z_old[:, 1], color="r", marker="o", alpha=.25)

Set `name` and `data_base_dir` below so that `<data_base_dir>/<name>/banana.csv` points to your local copy of the banana dataset.

In [None]:
# Sub-directory of data_base_dir in which load_data looks for banana.csv.
name = "banana"
# NOTE(review): machine-specific absolute path; must be edited before the
# notebook can run on another machine.
data_base_dir = "/Users/matt/projects/pvi/datasets"
# Fraction of rows (taken from the top of the file, rounded up) used for
# training; the remaining rows form the test set.
train_proportion = 0.08

# training_set / test_set are dicts with "x" and "y" arrays; D is the number
# of input columns.
training_set, test_set, D = load_data(
    name, train_proportion, data_base_dir)

In [None]:
# Number of clients the training data is split across.
M = 3

# Sort the training points by x1 and cut them into M consecutive chunks, one
# per client. Also returns each client's size and positive-label fraction.
clients_data, nis, prop_positive, M = generate_clients_data(
    training_set["x"], 
    training_set["y"],
    M=M,
    dataset_seed=0,
)

### Set up clients

In [None]:
# Number of inducing points per client (also sizes the natural parameters below).
num_inducing = 10

# Shared across all clients.
model_config = {
    "D": D,
    "num_inducing": num_inducing,
    # Factory that forwards its keyword arguments to RBFKernel.
    "kernel_class": lambda **kwargs: RBFKernel(**kwargs),
    "kernel_params": {
        "ard_num_dims": D, 
        "train_hypers": True
    },
    "num_predictive_samples": 100
}

# Passed as `config` to every FederatedSGPClient.
# NOTE(review): the same dict, and hence the same EarlyStopping instance, is
# handed to all clients; confirm the client copies it or that sharing is intended.
client_config = {
    "optimiser": "Adam",
    "optimiser_params": {"lr": 1e-2},
    "epochs": 2000,
    # Set to the size of the FIRST client's dataset; generate_clients_data
    # gives every client the same number of points when M > 1.
    "batch_size": len(clients_data[0]["x"]),
    "num_elbo_samples": 10,
    "num_elbo_hyper_samples": 2,
    "num_predictive_samples": 100,
    "train_model": False,
    "damping_factor": .25,
    "valid_factors": False,
    "early_stopping": EarlyStopping(50)
}

# Passed as `config` to the server.
server_config = {
    "max_iterations": 50,
    "train_model": False,
    "hyper_optimiser": "SGD",
    "hyper_optimiser_params": {"lr": 1},
    "hyper_updates": 10,
    "optimiser": "Adam",
    "optimiser_params": {"lr": 1e-2},
    "epochs": 2000,
    "early_stopping": EarlyStopping(25, delta=1e-2),
    "num_elbo_samples": 10,
}

# Initial natural parameters of each client's factor: an all-zero vector and
# an all-zero (num_inducing x num_inducing) matrix.
init_nat_params = {
    "np1": torch.zeros(model_config["num_inducing"]),
    "np2": torch.zeros(model_config["num_inducing"]).diag_embed(),
}

# Zero vector and -0.5 * identity; presumably the natural parameters of a
# standard normal prior (TODO confirm against the pvi parameterisation).
# NOTE(review): not referenced by any other cell visible in this notebook.
prior_nat_params = {
    "np1": torch.zeros(model_config["num_inducing"]),
    "np2": -0.5 * torch.ones(model_config["num_inducing"]).diag_embed(),
}

In [None]:
# Construct clients.
clients = []
# Private inducing point locations, one tensor per client.
z_is = []

# Fix both RNGs so the random choice of inducing points is reproducible.
torch.manual_seed(0)
np.random.seed(0)
for i in range(M):
    model_i = SparseGaussianProcessClassification(config=model_config)
    data_i = clients_data[i]
    
    # Randomly initialise private inducing points.
    # A random subset of this client's own inputs is used as the locations.
    perm = torch.randperm(len(data_i["x"]))
    idx = perm[:model_config["num_inducing"]]
    z_i = torch.tensor(data_i["x"][idx])
    z_is.append(z_i)
    
    # Convert to torch.tensor.
    # This rewrites the entries of clients_data[i] in place, so after this
    # cell clients_data holds tensors rather than NumPy arrays.
    for k, v in data_i.items():
        data_i[k] = torch.tensor(v)
    
    # NOTE(review): every factor receives the same init_nat_params dict;
    # confirm the factor copies it rather than sharing tensors across clients.
    t = MultivariateGaussianFactorWithZ(
        nat_params=init_nat_params,
        inducing_locations=z_i,
        train_inducing=True,
    )
    
    clients.append(FederatedSGPClient(data=data_i, model=model_i, t=t, config=client_config))

# Construct global model and server.
model = SparseGaussianProcessClassification(config=model_config)

# Union of z_is.
# NOTE(review): the q built here is overwritten by the randomly initialised
# one below before it is ever used. It is kept as the alternative
# initialisation referred to by the `maintain_inducing` comment further down;
# remove it if that alternative is not needed.
z = torch.cat(z_is)
kzz = model.kernel(z, z)
q = MultivariateGaussianDistributionWithZ(
    std_params = {
        "loc": torch.zeros(z.shape[0]),
        "covariance_matrix": kzz,
    }, 
    inducing_locations=z,
    train_inducing=True
)

# Randomly initialise global inducing points.
# Use the configured number of inducing points instead of a hard-coded 10 so
# this stays in sync with model_config.
perm = torch.randperm(len(training_set["x"]))
idx = perm[:model_config["num_inducing"]]
z = torch.tensor(training_set["x"][idx])
kzz = model.kernel(z, z)
# Zero-mean Gaussian over the inducing outputs with the kernel matrix at z as
# its covariance.
q = MultivariateGaussianDistributionWithZ(
    std_params = {
        "loc": torch.zeros(z.shape[0]),
        "covariance_matrix": kzz,
    }, 
    inducing_locations=z,
    train_inducing=True
)

server = SequentialSGPServer(
    model=model, 
    p=q, 
    clients=clients,
    config=server_config,
    maintain_inducing=True,    # Set to False to use union of inducing points.
)

### Run streaming SGP with private inducing points!

In [None]:
# Run server iterations until the server reports it should stop. After each
# one, print test accuracy and mean log-likelihood, then plot every client's
# data with its inducing points and the server's predictive contours.
while not server.should_stop():
    server.tick()

    # Obtain predictions.
    pp = server.model_predict(torch.tensor(test_set["x"]))
    preds = pp.mean.detach().numpy()

    test_acc, test_mll = acc_and_ll(preds, test_set["x"], test_set["y"])
    print(test_acc)
    print(test_mll)

    # Top row: one panel per client. Bottom row, middle column: the server.
    fig = plt.figure(figsize=(12, 6), dpi=100, constrained_layout=True)
    gs = fig.add_gridspec(2, 3)
    ax1, ax2, ax3 = (fig.add_subplot(gs[0, col]) for col in range(3))
    ax4 = fig.add_subplot(gs[1, 1])

    for i, (client, ax) in enumerate(zip(clients, (ax1, ax2, ax3))):
        client_inputs = client.data["x"]
        client_labels = client.data["y"]
        plot_predictive_distribution(
            client_inputs, client_labels, z=client.t.inducing_locations, ax=ax)
        ax.set_title(f"Client {i}")

    plot_predictive_distribution(
        training_set["x"],
        training_set["y"],
        z=server.q.inducing_locations,
        model=server.model,
        q=server.q,
        ax=ax4,
    )

    plt.show()