In [None]:
import time 
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import parameters_to_vector, vector_to_parameters

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

import numpy as np
from LANAM.models import LaNAM, NAM, BayesianLinearRegression
from LANAM.config import *
from LANAM.trainer import *
from LANAM.trainer.nam_trainer import train
from LANAM.data import *

from LANAM.utils.plotting import * 
from LANAM.utils.output_filter import OutputFilter
from LANAM.utils.wandb import *

from laplace import Laplace
from laplace import marglik_training as lamt
from laplace.curvature.backpack import BackPackGGN

In [None]:
%reload_ext autoreload
# Mode 2: reload all modules (e.g. edits to the LANAM package) before executing each cell.
%autoreload 2

## Bayesian Linear Regression
$$
\begin{align*}
y &= \beta_0 X_0 + \beta_1 X_1 + \epsilon, \\
\epsilon &\sim \mathcal{N}(0, \sigma^2), \\
\beta_0, \beta_1 &\sim \mathcal{N}(0, 1).
\end{align*}
$$

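Writing $X = [X_0, X_1]$ for the design matrix and $\Sigma_p = I$ for the prior covariance of $\beta = (\beta_0, \beta_1)^T$, the posterior is Gaussian, $p(\beta \mid X, y) = \mathcal{N}(\mu_\beta, \Sigma_\beta)$, with

$$
\begin{align*}
\mu_\beta &= \sigma^{-2}A^{-1}X^Ty,\\
\Sigma_\beta &= A^{-1},\\
A &= \sigma^{-2}X^TX + \Sigma_p^{-1}.
\end{align*}
$$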

In [None]:
from sklearn.linear_model import LinearRegression

# Sanity check with ordinary least squares on a noise-free toy problem:
# y = 1 * x_0 + 2 * x_1 + 3, so the fit should recover coef_ = [1, 2] and intercept_ = 3.
X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
y = X @ np.array([1, 2]) + 3
reg = LinearRegression().fit(X, y)

# Only a cell's last expression is rendered; bare expressions before it are
# evaluated and silently discarded, so collect every result into one value.
{
    'r2': reg.score(X, y),
    'coef': reg.coef_,
    'intercept': reg.intercept_,
    'prediction at [3, 5]': reg.predict(np.array([[3, 5]])),
}

In [None]:
# Draw samples from the BLR weight posterior for increasingly correlated inputs
# (rho is forwarded to linear_regression_example) and scatter them in (beta_0, beta_1) space.
rhos = [0, 0.7, 0.9, 0.95, 0.99]
SIGMA_NOISE = 1.0           # passed as `sigma` to the data generator and `sigma_noise` to the model
PRIOR_VAR = 1.0             # matches the N(0, 1) prior on beta in the derivation above
N_POSTERIOR_SAMPLES = 1000
SEED = 0

# Seeded generator created inside the cell so re-running it redraws identical samples.
rng = np.random.default_rng(SEED)

for rho in rhos:
    data = linear_regression_example(rho=rho, sigma=SIGMA_NOISE)
    input_0, input_1 = data.features.T
    target = data.targets
    data_plot = plot_3d(input_0, input_1, target)

    blr = BayesianLinearRegression(data.features, data.targets, bf='identity',
                                   sigma_noise=SIGMA_NOISE, prior_var=PRIOR_VAR)
    mean = blr.mean
    cov = blr.posterior_cov

    # Distinct names so the earlier cell's regression target `y` is not clobbered.
    beta_0_samples, beta_1_samples = rng.multivariate_normal(
        mean.flatten(), cov, N_POSTERIOR_SAMPLES).T

    fig, axs = plt.subplots(figsize=(4, 3))
    axs.plot(beta_0_samples, beta_1_samples, 'x')
    axs.axis('equal')
    axs.set_xlabel('beta_0')
    axs.set_ylabel('beta_1')
    # Without a title the five otherwise identical-looking figures cannot be told apart.
    axs.set_title(f'Posterior samples, rho = {rho}')
    plt.show()