In [4]:

import pandas as pd
import numpy as np
import jax.numpy as jnp
import jax
import jax.scipy as jsp
import jax.scipy.linalg as linalg
import matplotlib.pyplot as plt
from jax import jit, vmap
from jax.scipy.special import logsumexp
from jax import random

# distributions
from scipy.stats import norm
from scipy.stats import multivariate_normal as mvn
from scipy.stats import poisson


In [None]:
# Loading data
data = jnp.load("./data_assignment3.npz")
x, y = data["x"], data["t"]

# load with NumPy
data = np.load("./data_assignment3.npz")
x_np, y_np = data["x"], data["t"]

In [None]:
# Normal distribution
log_npdf = lambda x, m, v: -0.5*(x-m)**2/(v) - 0.5*jnp.log(2*jnp.pi*v)
npdf = lambda x, m, v: jnp.exp(log_npdf(x, m, v))

# Half-normal distribution
log_half_npdf = lambda x, m, v: jnp.log(2) -0.5*(x-m)**2/(v) - 0.5*jnp.log(2*jnp.pi*v)
half_npdf = lambda x, m, v: jnp.exp(log_half_npdf(x, m, v))

In [6]:
2.2
def evaluate_log_joint(theta):
    """
    Evaluate the marginalized log joint distribution
    log p(y, w0, w1, v, sigma0, sigma1, tau | x)
    
    Inputs:
        x         : (N, D) JAX array (feature matrix, NO bias yet)
        y         : (N,)   JAX array (targets)
        w0, w1, v : (D+1,) JAX arrays (parameters including bias term)
        sigma0    : scalar
        sigma1    : scalar
        tau       : scalar
    Returns:
        log_joint : scalar (log probability)
    """
    w0, w1, v, sigma0, sigma1, tau = theta

    # Helper: log pdf of normal
    def log_normal(y, mu, sigma):
        return -0.5 * (jnp.log(2 * jnp.pi) + 2 * jnp.log(sigma) + ((y - mu) / sigma)**2)

    # Helper: log prior of half-normal
    def log_half_normal(x):
        return jnp.where(x > 0,
                         jnp.log(2.0) + log_normal(x, 0.0, 1.0),
                         -jnp.inf)

    # Mixture likelihood per datapoint
    def per_datapoint_loglik(x_n, y_n):
        pi_n = jax.nn.sigmoid(jnp.dot(v, x_n))
        mu0 = jnp.dot(w0, x_n)
        mu1 = jnp.dot(w1, x_n)
        log_p0 = log_normal(y_n, mu0, sigma0)
        log_p1 = log_normal(y_n, mu1, sigma1)
        return logsumexp(jnp.array([
            jnp.log1p(-pi_n) + log_p0,
            jnp.log(pi_n) + log_p1
        ]))

    # Vectorized over all datapoints
    total_loglik = jnp.sum(vmap(per_datapoint_loglik)(x_np, y))

    # Priors
    logp_w0 = mvn.logpdf(w0.ravel(), mean=mean, cov=cov)
    logp_w1 = mvn.logpdf(w1.ravel(), mean=mean, cov=cov)
    logp_v  = mvn.logpdf(v .ravel(), mean=mean, cov=cov)
    logp_tau = log_half_normal(tau)
    logp_sigma0 = log_half_normal(sigma0)
    logp_sigma1 = log_half_normal(sigma1)

    total_logprior = logp_w0 + logp_w1 + logp_v + logp_tau + logp_sigma0 + logp_sigma1

    return total_loglik + total_logprior

if x.ndim == 1:
    x = x[:, None]

# Random parameters
key = jax.random.PRNGKey(0)
D = x.shape[0]
tau      = np.abs(np.random.randn())        # same as Half-Normal with σ=1
sigma0  = np.abs(np.random.randn())
sigma1   = np.abs(np.random.randn())

# Construct multivariate normal vectors for w0, w1, v
mean = np.zeros(D)
cov = tau**2 * np.eye(D)

w0 = np.random.normal(loc=0.0, scale=tau, size=(D,))
w1 = np.random.normal(loc=0.0, scale=tau, size=(D,))
v  = np.random.normal(loc=0.0, scale=tau, size=(D,))
logp_w0 = mvn.logpdf(w0.ravel(), mean=mean, cov=cov)
logp_w1 = mvn.logpdf(w1.ravel(), mean=mean, cov=cov)
logp_v  = mvn.logpdf(v .ravel(), mean=mean, cov=cov)
theta = [w0, w1, v, sigma0, sigma1, tau]

# Evaluate
logp = evaluate_log_joint(theta)

print("Marginalized log joint =", float(logp))

Marginalized log joint = -1121.0133056640625


In [7]:
def metropolis(log_target, num_params, epsilon, num_iter, theta_init=None, seed=0):    
    """ Runs a Metropolis-Hastings sampler 
    
        Arguments:
        log_target:         function for evaluating the log target distribution, i.e. log \tilde{p}(theta). The function expect a parameter of size num_params.
        num_params:         number of parameters of the joint distribution (integer)
        sigma:                standard deviation of the Gaussian proposal distribution (positive real)
        num_iter:           number of iterations (integer)
        theta_init:         vector of initial parameters (np.array with shape (num_params) or None)        
        seed:               seed (integer)

        returns
        thetas              np.array with MCMC samples (np.array with shape (num_iter+1, num_params))
    """ 
    
    # set initial key
    key = random.PRNGKey(seed)

    if theta_init is None:
        theta_init = jnp.zeros((num_params))
    
    # prepare lists 
    thetas = [theta_init]
    accepts = []
    log_p_theta = log_target(theta_init)
    
    for k in range(num_iter):

        # update keys: key_proposal for sampling proposal distribution and key_accept for deciding whether to accept or reject.
        key, key_proposal, key_accept = random.split(key, num=3)
        key_list = random.split(key_proposal, 6)

        ##############################################
        # Your solution goes here
        ##############################################
        

        # get the last value for theta and generate new proposal candidate
        theta_cur = thetas[-1]
        theta_star = [theta_cur[i] + epsilon * random.normal(key_list[i], shape=(len(theta_cur),))
              for i in range(6)]
        
        # evaluate the log density for the candidate sample
        log_p_theta_star = log_target(theta_star)

        # compute acceptance probability
        log_r = log_p_theta_star - log_p_theta
        A = min(1, jnp.exp(log_r))
        
        # accept new candidate with probability A
        if random.uniform(key_accept) < A:
            theta_next = theta_star
            log_p_theta = log_p_theta_star
            accepts.append(1)
        else:
            theta_next = theta_cur
            accepts.append(0)


        
        ##############################################
        # End of solution
        ##############################################
            
        thetas.append(theta_next)


        
    print('Acceptance ratio: %3.2f' % jnp.mean(jnp.array(accepts)))
        
    # return as np.array
    thetas = jnp.stack(thetas)

    # check dimensions and return
    assert thetas.shape == (num_iter+1, num_params), f'The shape of thetas was expected to be ({num_iter+1}, {num_params}), but the actual shape was {thetas.shape}. Please check your code.'
    return thetas



# Set number of parameters
num_params = 3*D + 3   # this is 3*D + 3

# specify the parameters of the MH algorithm
num_iterations = 2000
warm_up = int(0.5*num_iterations)
sigma = 0.5

theta = [w0, w1, v, sigma0, sigma1, tau]

# run sampler
thetas = metropolis(evaluate_log_joint, num_params, sigma, num_iterations, theta_init=theta, seed=0)

# plot resutls
xs = x_np
fig, axes = plt.subplots(1, 3, figsize=(20, 5))
axes[0].plot(thetas)
axes[0].set_xlabel('Iteration')
axes[0].set_ylabel('Parameter $\\theta$')
axes[0].set_title('Trace of parameter $\\theta$', fontweight='bold')

axes[1].hist(thetas, 30, density=True);
axes[1].plot(xs, np.exp(evaluate_log_joint(xs)), linewidth=3)
axes[1].set_xlabel('Parameter $\\theta$')
axes[1].set_title('Histogram of all samples', fontweight='bold')

axes[2].hist(thetas[warm_up:], 30, density=True);
axes[2].plot(xs, np.exp(evaluate_log_joint(xs)), linewidth=3)
axes[2].set_xlabel('Parameter $\\theta$');
axes[2].set_title('Histogram of all samples after warm-up', fontweight='bold');


TypeError: add got incompatible shapes for broadcasting: (200,), (6,).