In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
import logging

# Put the directory two levels above the working directory on the import path
# (presumably where the `pvi` package imported in the next cell lives).
module_path = os.path.abspath("../..")
if not any(entry == module_path for entry in sys.path):
    sys.path.append(module_path)

In [None]:
from pvi.models.linear_regression import LinearRegressionModel
from pvi.utils.gaussian import mvstandard2natural, mvnatural2standard

import torch
import numpy as np
import matplotlib.pyplot as plt
import tqdm.auto as tqdm

from torch import nn

%matplotlib inline
# Make float64 the default floating-point dtype, so tensors created without an
# explicit dtype (e.g. the torch.zeros / torch.eye calls further down) match
# the float64 data built from NumPy arrays.
torch.set_default_dtype(torch.float64)

# Set up data and helper functions

In [None]:
# Synthetic 1-D regression data: ten evenly spaced inputs on [-1, 1] and
# targets 2*x plus uniform noise in [0, 3*|x|), so the noise grows away from 0.
# A dedicated seeded generator keeps the dataset identical across
# "Restart Kernel -> Run All" without touching NumPy's global random state.
rng = np.random.default_rng(0)

x = np.linspace(-1, 1, 10)
y = 2 * x + 3 * np.abs(x) * rng.random(len(x))

# Add a trailing feature axis: both tensors end up with shape (10, 1).
x = torch.tensor(x).unsqueeze(1)
y = torch.tensor(y).unsqueeze(1)

In [None]:
def plot_data(x, y):
    """Scatter-plot the observations ``y`` against the inputs ``x``.

    Opens a new figure with a grid, draws the scatter and shows it.

    Args:
        x: Input locations, forwarded to ``plt.scatter``.
        y: Targets, forwarded to ``plt.scatter``.
    """
    plt.figure()
    # Pass the flag positionally: the ``b=`` keyword was renamed to
    # ``visible=`` in Matplotlib 3.5 and later removed, so ``grid(b=True)``
    # fails on current releases while ``grid(True)`` works on all of them.
    plt.grid(True)
    plt.scatter(x, y)
    plt.show()
    
def plot_results(x, y, model, q, n_samples=20):
    """Plot the model's predictions under ``q`` together with the data.

    On a single new figure this draws:
      * one faint line ``x * w[0] + w[1]`` per weight sample ``w`` drawn from
        ``q["distribution"]``,
      * the predictive mean returned by ``model(x, q)``,
      * a band of +/- 1.96 predictive standard deviations around that mean
        (roughly a 95% interval if the predictive is Gaussian),
      * a scatter of the observations.

    Args:
        x: Inputs; ``x.squeeze(-1)`` is used as the horizontal axis, so a
            trailing singleton feature axis is assumed.
        y: Targets, forwarded to ``plt.scatter``.
        model: Callable as ``model(x, q)``; the result must expose ``mean``
            and ``variance`` tensors.
        q: Dict with a ``"distribution"`` entry providing ``sample``.
        n_samples: Number of weight samples (lines) to draw. Defaults to 20.
    """
    pp = model(x, q)
    # Detach so the tensors can be handed to Matplotlib even when they are
    # still attached to an autograd graph.
    mean = pp.mean.detach()
    std = pp.variance.detach() ** 0.5

    w_samples = q["distribution"].sample((n_samples,))

    plt.figure()
    # Pass the flag positionally: the ``b=`` keyword was renamed to
    # ``visible=`` in Matplotlib 3.5 and later removed.
    plt.grid(True)

    for w in w_samples:
        plt.plot(x, x * w[0] + w[1], color='k', alpha=.1)

    plt.plot(x.squeeze(-1), mean)
    plt.fill_between(x.squeeze(-1), mean-1.96*std, mean+1.96*std, alpha=.25)
    plt.scatter(x, y)
    plt.show()
    
def plot_training(training_array, ylabel='ELBO Loss'):
    """Plot a per-step training curve.

    Args:
        training_array: Sequence of recorded values, one per optimisation
            step; plotted against the step numbers 1..len(training_array).
        ylabel: Label for the vertical axis. The default is kept for
            backward compatibility; pass a different label when the series is
            not an ELBO loss (e.g. a marginal log-likelihood).
    """
    x_vals = np.arange(1, len(training_array)+1)
    plt.figure()
    # Pass the flag positionally: the ``b=`` keyword was renamed to
    # ``visible=`` in Matplotlib 3.5 and later removed.
    plt.grid(True)
    plt.plot(x_vals, training_array)
    plt.ylabel(ylabel)
    plt.xlabel('Step')
    plt.show()
    
# Bundle inputs and targets under the "x" / "y" keys; this dict is what gets
# handed to the model's update and fitting routines further down.
data = {"x": x, "y": y}

In [None]:
# Visualise the raw synthetic dataset before any model is involved.
plot_data(x, y)

# Construct linear regression model

In [None]:
# Model configuration. "D" is presumably the input dimension (the natural
# parameters below are sized D + 1); the remaining keys are read by the
# training loop to build the optimiser and choose the number of epochs.
hyperparameters = {
    "D": 1,
    "epochs": 1000,
    "optimiser_class": torch.optim.Adam,
    "optimiser_params": {"lr": 1e-2},
}

model = LinearRegressionModel(output_sigma=.25, hyperparameters=hyperparameters)

# Initial q in natural parameters: a zero vector and -0.5 * identity.
# NOTE(review): under the usual Gaussian convention (np1 = P @ mu, np2 = -P / 2)
# this is a standard normal -- confirm against mvnatural2standard.
q = {
    "nat_params": {
        "np1": torch.zeros(hyperparameters["D"] + 1),
        "np2": torch.eye(hyperparameters["D"] + 1) * -0.5,
    },
}

# Convert to mean / covariance and attach the matching Gaussian, which the
# plotting helper samples from.
q_mu, q_cov = mvnatural2standard(
    q["nat_params"]["np1"],
    q["nat_params"]["np2"],
)
q["distribution"] = q_dist = torch.distributions.MultivariateNormal(
    loc=q_mu, covariance_matrix=q_cov
)

In [None]:
# Predictions and weight samples under the initial q built above, i.e. before
# any update from the data.
plot_results(x, y, model, q)

# Fit data

In [None]:
# Initial factor t: all-zero natural parameters, a length-(D + 1) vector and a
# (D + 1) x (D + 1) matrix, matching the shapes used for q.
t = {"nat_params": {}}
t["nat_params"]["np1"] = torch.zeros(model.hyperparameters["D"] + 1)
t["nat_params"]["np2"] = torch.zeros(
    model.hyperparameters["D"] + 1, model.hyperparameters["D"] + 1
)

## First without optimising model hyperparameters

In [None]:
# One call to the model's conjugate update with the hyperparameters exactly as
# constructed (no optimiser involved); returns the updated q and factor t.
q_new, t_new = model.conjugate_update(data, q, t)

In [None]:
# Turn the updated natural parameters into mean / covariance form and attach
# the matching Gaussian under "distribution", which plot_results samples from.
qmu, qcov = mvnatural2standard(
    q_new["nat_params"]["np1"],
    q_new["nat_params"]["np2"],
)
q_new["distribution"] = q_dist = torch.distributions.MultivariateNormal(
    loc=qmu, covariance_matrix=qcov
)

In [None]:
# Same diagnostic plot as before, now under the q returned by the conjugate
# update.
plot_results(x, y, model, q_new)

## Now optimising model hyperparameters

In [None]:
def _attach_distribution(q):
    """Fill ``q["distribution"]`` from ``q["nat_params"]`` and return ``q``.

    The natural parameters ``np1`` / ``np2`` are converted with
    ``mvnatural2standard`` and wrapped in a ``MultivariateNormal``. ``q`` is
    modified in place.
    """
    q_mu, q_cov = mvnatural2standard(q["nat_params"]["np1"], q["nat_params"]["np2"])
    q["distribution"] = torch.distributions.MultivariateNormal(
        q_mu, covariance_matrix=q_cov
    )
    return q


def fit(model, data, q, t_i):
    """Optimise the model's parameters by gradient ascent on ``model.mll``.

    Each epoch evaluates ``model.mll(data, q)``, takes one optimiser step on
    its negation, and then recomputes ``q`` and ``t_i`` with
    ``model.conjugate_update`` so that they reflect the updated parameters.

    Args:
        model: Object exposing ``hyperparameters`` (indexable with the keys
            ``"optimiser_class"``, ``"optimiser_params"`` and ``"epochs"``),
            ``parameters()``, ``conjugate_update(data, q, t_i)`` and
            ``mll(data, q)``.
        data: Passed through unchanged to ``conjugate_update`` and ``mll``.
        q: Dict whose ``"nat_params"`` entry holds ``"np1"`` and ``"np2"``.
        t_i: Factor dict passed to ``conjugate_update``.

    Returns:
        Tuple ``(q, t_i, training_curves)``: the final outputs of
        ``conjugate_update`` (with ``q["distribution"]`` populated) and a dict
        whose ``"mll"`` list holds one float per epoch.
    """
    # Set up optimiser. ``torch.optim`` is referenced explicitly because no
    # bare ``optim`` name is imported in this notebook; the Adam fallback
    # would otherwise raise NameError when no optimiser class is configured.
    optimiser_class = model.hyperparameters["optimiser_class"]
    if optimiser_class is None:
        optimiser_class = torch.optim.Adam
    optimiser = optimiser_class(
        model.parameters(), **model.hyperparameters["optimiser_params"]
    )

    training_curves = {
        "mll": [],
    }

    # Compute local factor and current global posterior.
    q, t_i = model.conjugate_update(data, q, t_i)
    q = _attach_distribution(q)

    # No cavity distribution (q minus t_i) is built here: nothing in this loop
    # reads one, since conjugate_update is handed q and t_i directly.
    epoch_iter = tqdm.tqdm(range(model.hyperparameters["epochs"]), desc="Epochs")
    for _ in epoch_iter:
        # Compute MLL; the optimiser minimises, so the loss is its negation.
        mll = model.mll(data, q)
        loss = -mll

        # Backwards step.
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()

        # Recompute local factor and global posterior with the new parameters.
        # This is not wrapped in torch.no_grad(), so q stays attached to the
        # autograd graph for the next epoch's MLL.
        # NOTE(review): confirm that gradients through q are intended.
        q, t_i = model.conjugate_update(data, q, t_i)
        q = _attach_distribution(q)

        # Record the value that was optimised in this epoch (pre-step MLL).
        training_curves["mll"].append(mll.item())

    return q, t_i, training_curves

In [None]:
# Run the optimisation loop starting from the initial q and t; fit returns the
# final q, the final factor and the MLL recorded at every epoch.
q_new, t_new, training_curves = fit(model, data, q, t)

In [None]:
# Per-epoch MLL values recorded by fit.
# NOTE(review): plot_training labels the y-axis 'ELBO Loss' by default, but
# this series is the un-negated MLL -- confirm the intended label.
plot_training(training_curves["mll"])

In [None]:
# Predictions under the q returned by fit, i.e. after the optimisation loop
# has updated the model's parameters.
plot_results(x, y, model, q_new)