# In [1]:
import jax.numpy as jnp
from scipy.optimize import minimize
from scipy.stats import multivariate_normal as mvn
import matplotlib.pyplot as plt

# Helper functions
def sigmoid(x):
    """Logistic function σ(x) = 1 / (1 + exp(-x)), applied elementwise."""
    return 1.0 / (1.0 + jnp.exp(-x))

class LogisticRegression(object):
    """
    Bayesian Logistic Regression with MAP estimation.

    Model:
    - Prior: w ~ N(0, α⁻¹I)
    - Likelihood: y_n | w, x_n ~ Bernoulli(σ(w^T x_n))

    Inputs are passed through ``feature_transformation`` and standardized
    column-wise; the MAP weights are computed at construction time and
    stored in ``w_MAP``.
    """

    def __init__(self, X, y, feature_transformation=lambda x: x, alpha=1.):
        """
        Store the data, standardize the features and fit the MAP estimate.

        X: raw inputs, one row per observation.
        y: binary targets (0/1), one per row of X.
        feature_transformation: callable mapping raw inputs to a 2-D
            feature matrix (identity by default).
        alpha: precision of the zero-mean Gaussian prior on the weights.

        Raises ValueError (from get_MAP) if the optimizer does not converge.
        """
        # Store data and hyperparameters
        self.X0 = X
        self.y = y
        self.alpha = alpha
        self.feature_transformation = feature_transformation

        # Transform once, then standardize using statistics of the
        # training features; the same statistics are reused by preprocess().
        features = feature_transformation(self.X0)
        self.X_mean = jnp.mean(features, 0)
        self.X_std = jnp.std(features, 0)
        # Constant columns have zero spread; divide by 1 so they become 0
        # instead of NaN.
        self.X_std = jnp.where(self.X_std == 0, 1, self.X_std)
        self.X = (features - self.X_mean) / self.X_std

        self.N, self.D = self.X.shape

        # Find MAP estimate
        self.w_MAP = self.get_MAP()

    def preprocess(self, X_):
        """Apply the feature transformation and the training standardization to raw inputs."""
        X = self.feature_transformation(X_)
        return (X - self.X_mean) / self.X_std

    def predict(self, X, w):
        """Compute p(y=1|X, w) = σ(Xw) for already-preprocessed features X."""
        return sigmoid(X @ w)

    def log_joint(self, w):
        """
        Log joint probability (up to the prior's normalizing constant):
        log p(y, w) = log p(w) + Σ_n log p(y_n | x_n, w)
        """
        # Log prior: -½ α w^T w
        log_prior = -0.5 * self.alpha * jnp.sum(w**2)

        # Log likelihood written in terms of the logits f_n = x_n^T w:
        #   y_n log σ(f_n) + (1-y_n) log(1-σ(f_n)) = y_n f_n - log(1 + e^{f_n})
        # Using logaddexp avoids log(0) (and the resulting 0 * -inf = nan)
        # when σ(f_n) saturates to exactly 0 or 1.
        f = self.X @ w
        log_lik = jnp.sum(self.y * f - jnp.logaddexp(0., f))

        return log_prior + log_lik

    def grad(self, w):
        """
        Gradient of log joint:
        ∇log p(y,w) = -Σ_n (p_n - y_n) x_n - α w
        """
        p = self.predict(self.X, w)
        err = p - self.y
        return -self.X.T @ err - self.alpha * w

    def hessian(self, w):
        """
        Hessian of log joint for a single weight vector w of shape (D,):
        H = -X^T diag(p(1-p)) X - αI
        """
        p = self.predict(self.X, w)
        v = p * (1 - p)
        # Scale the columns of X^T by v instead of materializing the
        # N x N matrix diag(v); the product is the same.
        return -(self.X.T * v) @ self.X - self.alpha * jnp.eye(self.D)

    def get_MAP(self):
        """
        Find the MAP estimate by minimizing the negative log joint.

        Starts from w = 0 and supplies the analytic gradient to
        scipy.optimize.minimize. Raises ValueError if the optimizer
        reports failure.
        """
        init_w = jnp.zeros(self.D)
        results = minimize(
            lambda w: -self.log_joint(w),
            jac=lambda w: -self.grad(w),
            x0=init_w
        )
        if not results.success:
            raise ValueError('Optimization failed')
        return results.x

class LaplaceApproximation(object):
    """
    Laplace approximation for posterior p(w|y) ≈ N(w|m, S)
    where m = w_MAP and S = -H⁻¹, with H the Hessian of the model's
    log joint evaluated at w_MAP.
    """

    def __init__(self, model):
        """
        Build the Gaussian approximation around the model's MAP estimate.

        model: object exposing ``w_MAP`` (1-D weight vector) and
            ``hessian(w)`` returning the (D, D) Hessian of its log joint
            at a single weight vector.
        """
        self.model = model

        # Posterior mean = MAP
        self.posterior_mean = model.w_MAP

        # Hessian at MAP. hessian() expects one 1-D weight vector, so pass
        # w_MAP directly: adding a leading axis breaks the X @ w product,
        # and indexing the result with [0] would keep only the first row.
        self.posterior_hessian = model.hessian(model.w_MAP)

        # Posterior covariance = -H⁻¹
        self.posterior_cov = -jnp.linalg.inv(self.posterior_hessian)

    def log_pdf(self, w):
        """Log density of the approximate posterior N(m, S) at w."""
        return mvn.logpdf(w, self.posterior_mean, self.posterior_cov)

    def sample(self, key, num_samples):
        """
        Draw ``num_samples`` samples from the approximate posterior.

        NOTE: ``key`` is accepted but not used; sampling goes through
        scipy's ``multivariate_normal.rvs`` with no ``random_state``, so
        draws are not controlled by ``key``.
        """
        return mvn.rvs(self.posterior_mean, self.posterior_cov, size=num_samples)