## Lista 5 - Aprendizagem de Máquina Probabilístico
- Aluno: Lucas Rodrigues Aragão - Graduação 538390

In [1]:
import numpy as np
import pandas as pd
import scipy as sp
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cholesky


## Modelos

### Inferência com modelos de GP para regressão 
- Estimação 

    1. Inicializar hiperparâmetros $\theta = \big[  \sigma^2_f, w^2_1, \cdots, w^2_D, \sigma^2_y \big]^T$

    2. Repetir até convergir ou número de épocas

        1. Calcular a evidência do modelo $\log p(y|X, \theta)$. A evidência é calculada via Cholesky.
            $$K+ \sigma^2_y = LL^T, \alpha = L^{-1}y$$
            $$\mathcal{L} (\theta) = - \sum_{i}{\log L_{ii} - \frac{1}{2} \alpha^T \alpha - \frac{N}{2} \log (2 \pi)}$$
        2. Calcular os gradientes analíticos $\frac{\partial \log p(y|X, \theta)}{\partial \theta}$.
        3. Atualizar $\theta$ a partir dos gradientes.

    3. Retornar os hiperparâmetros

- Predição

1. Dado um novo padrão $x_\ast$ retornar a distribuição preditiva 

$$p(y| x_\ast, y, X, \hat{\theta}) = \mathcal{N}(y_\ast| \mu_\ast, \sigma^2_\ast + \sigma^2_y)$$

$$\mu_\ast = k_{f \ast}^T (K + \sigma^2_y I)^{-1} y$$

$$\sigma^2_\ast = k_{\ast \ast} - k_{f \ast}^T (K + \sigma^2_y I)^{-1} k_{f \ast}$$

Em que, $k_{f \ast} = [k(x_\ast, x_1), \cdots , k(x_\ast, x_N)]$ e $k_{\ast \ast} = k(x_\ast, x_\ast)$

Além disso, valores de inicialização comuns são, $\sigma^2_f = \mathbb{V}[y]$, $w^2_d = \frac{1}{\mathbb{V}[X_{:d}]}$ e $\sigma^2_y = 0.01 \sigma^2_f$, com o $\mathbb{V}$ sendo a variância.


In [None]:
class RegGP:
    def __init__(self, sigma2f, w0, sigma2y):
        self.theta = jnp.array([sigma2f, w0, sigma2y])

    def RBF(self, theta, xi, xj):
        sigma2f = theta[0]
        diff = xi-xj
        w = theta[1:-1] 
        temp = jnp.sum(w * diff**2)
        return sigma2f * jnp.exp(-0.5 * temp)
        
    def apply_kernel(self, X, theta):
        return jax.vmap(lambda xi: jax.vmap(lambda xj: self.RBF(theta , xi, xj))(X))(X)


    def estimate(self, X_train, y_train ,epochs, lr):
        self.y = y_train
        grad_evidence = jax.grad(self.evidence)
        theta = self.theta

        for epoch in range(epochs):
            #calcular a evidencia log P(y|X,theta), via cholesky            
            #calcular os gradientes 
            # atualizar os hiperparametros
            grads = grad_evidence(theta=theta,X=X_train, y= y_train)
            theta = theta + lr*grads
            
        self.theta = theta
        return theta
    
    def evidence(self, theta, X ,y):
        K = self.apply_kernel(theta, X)
        sigma2y = theta[-1]
        N = y.shape[0]
        L = cholesky(K + sigma2y * jnp.eye(N))
        alpha = jnp.linalg.solve(L.T, jnp.linalg.solve(L,y))
        evidencia = - jnp.sum(jnp.log(jnp.diagonal(L))) - 0.5 * alpha.T @ alpha - 0.5 * N * jnp.log(2 * jnp.pi)

        return evidencia

    def predict(self, X_ast):
        #TODO: testar predict
        theta = self.theta
        sigma2y = theta[-1] 

        K = self.apply_kernel(X_ast, theta)
        kf = jax.vmap(lambda xi: self.RBF(theta, X_ast, xi) (X_ast))
        K_sigma = jnp.linalg.inv((K + sigma2y * jnp.eye(X_ast.shape[0])))
        mu_ast = kf.T @ K_sigma @ self.y
        k_ast_ast = self.RBF(theta, X_ast, X_ast)
        sigma2_ast = k_ast_ast - kf.T @ K_sigma @ kf

        return np.random.multivariate_normal(mean= mu_ast, cov= sigma2_ast + sigma2y)
    

## Métricas
$$RMSE = \sqrt{\frac{\sum_{i = 1}^{N_{teste}} (y_i - \hat{y_i})^2}{N}}$$
$$NLPD = \frac{1}{2} \log 2 \pi + \frac{1}{2N_{teste}} \sum^{N_{teste}}_{i=1} \big [ \log \hat{\sigma_i}^2  + \frac{(y_i- \hat{\mu_i})^2}{\hat{\sigma}_i^2}\big]$$


In [19]:
def RMSE(y_true, y_pred):

    return np.sqrt((np.sum((y_true - y_pred)**2)/y_pred.shape[0]))

def NLPD(y_true, pred_var, pred_mean):
    
    temp = np.sum( np.log(pred_var) + (y_true - pred_mean)**2 /pred_var)
    return 0.5 * np.log(2 * np.pi) + 1/(2 * y_true.shape[0]) * temp

## Questão 1

In [2]:
data = pd.read_csv("gp_data_train.csv", header= None)
data

Unnamed: 0,0,1
0,0.392938,-0.302971
1,-0.427721,0.324264
2,-0.546297,0.314647
3,0.102630,0.384257
4,0.438938,-0.398080
...,...,...
95,0.383404,-0.567651
96,-0.697745,-0.190597
97,-0.202247,0.049610
98,-0.518288,0.209381


In [21]:
X = jnp.asarray(data[0])
y = jnp.asarray(data[1])

sigma2f = jnp.var(y)
sigma2y = 0.001 * sigma2f
w0 = 1/jnp.var(X)

modelGP = RegGP(sigma2f = sigma2f, w0 = w0, sigma2y= sigma2y)

modelGP.estimate(X, y, epochs= 5, lr= 0.001)


TypeError: differentiating with respect to argnums=0 requires at least 1 positional arguments to be passed by the caller, but got only 0 positional arguments.