In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
import logging

# Make the directory three levels up importable (presumably the repository
# root containing the local `pvi` package imported below). The guard keeps
# re-running this cell from appending duplicate entries to sys.path.
module_path = os.path.abspath("../../..")
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
from pvi.models.linear_regression import LinearRegressionModel
from pvi.utils.gaussian import mvstandard2natural, mvnatural2standard
from pvi.clients.synchronous_client import SynchronousClient
from pvi.distributions.exponential_family_distributions import MultivariateGaussianDistribution
from pvi.distributions.exponential_family_factors import MultivariateGaussianFactor

import torch
import numpy as np
import matplotlib.pyplot as plt
import tqdm.auto as tqdm

from torch import nn

%matplotlib inline
# Use float64 for every newly created torch tensor in this notebook.
torch.set_default_dtype(torch.float64)

# Seed every RNG the notebook draws from (numpy for the synthetic noise,
# torch presumably for posterior weight samples / training) so that
# Restart & Run All reproduces the same data and figures.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)  # last expression returns None, so the cell shows no output

# Set up data and helper functions

In [None]:
# Synthetic 1-D data: a line of slope 2 plus non-negative noise drawn
# uniformly from [0, 3|x|), so the spread grows with |x|.
x = np.linspace(-1, 1, num=10)
y = 2 * x + 3 * np.abs(x) * np.random.rand(x.size)

# Convert both arrays to (N, 1) column tensors (dtype float64, taken from numpy).
x, y = (torch.tensor(arr).unsqueeze(1) for arr in (x, y))

In [None]:
def plot_data(x, y):
    """Scatter-plot the raw regression data in a new figure.

    Args:
        x: Inputs, anything ``plt.scatter`` accepts (an (N, 1) tensor here).
        y: Targets, same length as ``x``.
    """
    plt.figure()
    # Pass the visibility flag positionally: the ``b=`` keyword was renamed to
    # ``visible`` in Matplotlib 3.5 and is rejected by newer releases.
    plt.grid(True)
    plt.scatter(x, y)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.show()
    
def plot_results(x, y, model, q, num_samples=20):
    """Plot the predictive distribution of ``model`` under ``q`` over the data.

    Overlays regression lines for ``num_samples`` weight vectors drawn from
    ``q``, the predictive mean, a band of mean +/- 1.96 std (about 95% under
    a Gaussian predictive), and the observed data.

    Args:
        x: (N, 1) input tensor; also used as the plotting grid.
        y: (N, 1) target tensor.
        model: Callable ``model(x, q)`` returning an object with ``.mean``
            and ``.variance`` tensors.
        q: Weight distribution exposing ``q.distribution.sample(shape)``.
        num_samples: Number of weight vectors to sample and draw.
    """
    pp = model(x, q)
    mean = pp.mean.detach()
    std = pp.variance.detach() ** 0.5

    w_samples = q.distribution.sample((num_samples,))

    plt.figure()
    # Positional flag: ``grid(b=...)`` fails on newer Matplotlib releases.
    plt.grid(True)

    for w in w_samples:
        # NOTE(review): assumes w[0] is the slope and w[1] the bias; confirm
        # against LinearRegressionModel's feature ordering.
        plt.plot(x, x * w[0] + w[1], color='k', alpha=.1)

    plt.plot(x.squeeze(-1), mean)
    plt.fill_between(x.squeeze(-1), mean - 1.96 * std, mean + 1.96 * std, alpha=.25)
    plt.scatter(x, y)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.show()
    
def plot_training(training_array):
    """Plot a training curve of ELBO loss against optimisation step.

    Args:
        training_array: Sequence of per-step loss values. Steps on the
            x-axis are numbered from 1.
    """
    steps = np.arange(1, len(training_array) + 1)
    plt.figure()
    # Positional flag: ``grid(b=...)`` was renamed to ``visible`` in
    # Matplotlib 3.5 and is rejected by newer releases.
    plt.grid(True)
    plt.plot(steps, training_array)
    plt.ylabel('ELBO Loss')
    plt.xlabel('Step')
    plt.show()
    
# Bundle inputs and targets for the client. It presumably reads them back
# under the keys "x" and "y".
data = dict(x=x, y=y)

In [None]:
plot_data(x=x, y=y)

# Construct linear regression model

In [None]:
# Settings handed to LinearRegressionModel. D presumably is the number of
# input features (x has a single column here); the epoch/optimiser entries
# are presumably read by the model's own training loop.
hyperparameters = dict(
    D=1,
    epochs=1000,
    optimiser="Adam",
    optimiser_params=dict(lr=1e-2),
)

# NOTE(review): output_sigma is presumably the fixed observation-noise std.
model = LinearRegressionModel(hyperparameters=hyperparameters, output_sigma=.25)

# Initial distribution over the D + 1 weights (presumably D slopes plus a
# bias). Assuming the usual natural parametrisation np1 = Sigma^-1 mu and
# np2 = -1/2 Sigma^-1, these values correspond to a standard normal N(0, I).
q = MultivariateGaussianDistribution(
    nat_params=dict(
        np1=torch.zeros(hyperparameters["D"] + 1),
        np2=-0.5 * torch.eye(hyperparameters["D"] + 1),
    )
)

In [None]:
# Predictive under the initial q, before any fitting.
plot_results(x, y, model=model, q=q)

# Fit data

In [None]:
# Initial site factor with all-zero natural parameters, i.e. a constant
# (flat) factor that initially contributes nothing to q.
t = MultivariateGaussianFactor(
    nat_params=dict(
        np1=torch.zeros(model.hyperparameters["D"] + 1),
        np2=torch.zeros(model.hyperparameters["D"] + 1, model.hyperparameters["D"] + 1),
    )
)

In [None]:
# Synchronous client wrapping the model, the dataset and the initial
# (all-zero) site factor t.
client = SynchronousClient(
    model=model,
    data=data,
    t=t,
)

## Without optimising model (hyper-)parameters

In [None]:
# Run the client update starting from the initial q (model hyperparameters are
# not optimised in this section). The returned pair is unpacked into the new
# approximation q_new and, presumably, the client's updated site factor t_new.
q_new, t_new = client.update_q(q)

In [None]:
# Sanity check: the update should return a distribution of exactly the model's
# conjugate family. An assert (rather than a bare comparison that only
# displays True/False) makes Restart & Run All fail loudly if this breaks.
assert type(q_new) is client.model.conjugate_family, (
    f"expected {client.model.conjugate_family}, got {type(q_new)}"
)

In [None]:
# Predictive under the updated approximation returned by the client.
plot_results(x, y, model=client.model, q=q_new)