In [2]:

import numpy as np
import matplotlib.pyplot as plt
# Make the course folder importable so the local utils.py can be found.
import sys
import os
import jax.numpy as jnp
# NOTE(review): absolute, machine-specific path - adjust it when running this
# notebook on another machine.
folder_path = r'C:\Users\Petrb\Desktop\DTU\3rdSemester\02477_BAYESIAN_MACHINE_LEARNING'

# Append the folder to the module search path
sys.path.append(folder_path)

# NOTE(review): wildcard import - it is not visible from here which names
# utils provides; TODO confirm what is actually needed and import it by name.
from utils import *

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import bernoulli

## Part 1: Multi-class classification

Consider the following linear model for multi-class classification with $ K = 3 $ classes:

$$
y_n | \mathbf{f}_n \sim \text{Categorical}(\text{softmax}(\mathbf{f}_n)), \tag{1}
$$

$$
\mathbf{f}_n = \mathbf{W} \phi(x_n), \tag{2}
$$

$$
W_{ij} \sim \mathcal{N}(0, \alpha^{-1}), \tag{3}
$$

where $ y_n \in \{1,2,3\} $, $ x_n \in \mathbb{R} $, $ \alpha > 0 $ is a hyperparameter, and $ \mathbf{W} $ are the parameters of interest. The feature transformation $ \phi(x) $ is given by $ \phi(x) = \begin{bmatrix} 1 & x \end{bmatrix}^T $ such that $ \mathbf{W} \in \mathbb{R}^{K \times D} $ for $ D = 2 $.

---

### Question 1.1: Identify the prior and likelihood of the model.


Let

$$
\hat{\mathbf{W}}_{\text{MAP}} = \begin{bmatrix}
-0.5 & -2.0 \\
3.0 & 0.0 \\
1.0 & 1.0
\end{bmatrix} \tag{4}
$$

be a MAP-estimator for the model given in eq. (1)–(3) for some dataset $ \mathcal{D} $ (not given).

---

### Question 1.2: Use the plugin approximation with $ \hat{\mathbf{W}}_{\text{MAP}} $ to compute the posterior predictive distribution for $ x^* = -1 $.




In [23]:
# MAP estimate of the weights from eq. (4): one row per class, the columns
# are the coefficients of the features [1, x].
w_MAP = jnp.array(
    [
        [-0.5, -2.0],
        [3.0, 0.0],
        [1.0, 1.0],
    ]
)
print(f"w_MAP: \n {w_MAP}")
print(f"w_MAP shape: {w_MAP.shape}")

# Test input at which the predictive distribution is evaluated
x_star = -1.0

def f(x, w):
    """
    Evaluate the linear model by projecting the input onto every row of ``w``.

    :param x: Feature vector or batch of feature vectors; its last axis must
              match the number of columns of ``w``.
    :param w: Weight matrix with one row per output.
    :return: ``jnp.dot(x, w.T)``, i.e. one output per row of ``w`` for every
             feature vector in ``x``.
    """
    w_transposed = w.T
    return jnp.dot(x, w_transposed)


def phi(x):
    """
    Feature map of the model, phi(x) = [1, x]^T (bias term plus the raw input).

    Works for a single scalar input as well as for an array of inputs, in
    which case the features are placed on a new trailing axis.

    :param x: Scalar input or array of inputs of any shape ``S``.
    :return: Array of shape ``S + (2,)`` where ``[..., 0]`` is 1 and
             ``[..., 1]`` is ``x``; a scalar input gives shape ``(2,)``.
    """
    x = jnp.asarray(x)
    # Stacking on the last axis turns a batch of N inputs into an (N, 2)
    # design matrix while a scalar still yields the plain vector [1, x].
    return jnp.stack([jnp.ones_like(x), x], axis=-1)

# Logits f* = W_MAP phi(x*) for the test input
pred = jnp.matmul(w_MAP, phi(x_star))
print(f"Pred for x_star: {pred}")
print(f"Pred for x_stars shape: {pred.shape}")


def softmax(x, axis=None):
    """
    Numerically stable softmax.

    The maximum is subtracted before exponentiating so that large logits do
    not overflow to ``inf`` (which would turn the result into NaN); this
    shift leaves the mathematical value of the softmax unchanged.

    :param x: Array of logits.
    :param axis: Axis (or axes) to normalise over. The default ``None``
                 normalises over all elements, which for a single 1-D logit
                 vector is the usual softmax; pass ``axis=-1`` for a batch
                 of logit vectors stored row-wise.
    :return: Array with the same shape as ``x`` whose entries sum to 1 along
             ``axis``.
    """
    x = jnp.asarray(x)
    shifted = x - jnp.max(x, axis=axis, keepdims=True)
    exp_shifted = jnp.exp(shifted)
    return exp_shifted / jnp.sum(exp_shifted, axis=axis, keepdims=True)

# Plug-in predictive distribution for x_star: softmax of the MAP logits
print(f"Softmax: {softmax(pred)}")  



w_MAP: 
 [[-0.5 -2. ]
 [ 3.   0. ]
 [ 1.   1. ]]
w_MAP shape: (3, 2)
Pred for x_star: [1.5 3.  0. ]
Pred for x_stars shape: (3,)
Softmax: [0.17529039 0.785597   0.03911257]


Let $ \mathbf{W}^{(i)} \sim q(\mathbf{W}) $ for $ i = 1, 2, 3 $ be samples from a variational approximation of the posterior, i.e. $ p(\mathbf{W}|\mathcal{D}) \approx q(\mathbf{W}) $:

$$
\mathbf{W}^{(1)} = \begin{bmatrix}
-0.15 & -1.92 \\
3.2 & 0.45 \\
1.37 & 0.8
\end{bmatrix}, \quad
\mathbf{W}^{(2)} = \begin{bmatrix}
-0.31 & -2.03 \\
2.98 & 0.08 \\
1.03 & 1.29
\end{bmatrix}, \quad
\mathbf{W}^{(3)} = \begin{bmatrix}
-0.35 & -1.98 \\
3.09 & 0.07 \\
1.3 & 0.96
\end{bmatrix}. \tag{5}
$$



### Question 1.3: Compute a Monte Carlo estimate of the posterior predictive distribution for $ x^* = -1 $ using the samples given above.



In [31]:
# Samples W^(1), W^(2), W^(3) from the variational approximation q(W), eq. (5)
w_1 = jnp.array(
    [
        [-0.15, -1.92],
        [3.2, 0.45],
        [1.37, 0.8],
    ]
)
w_2 = jnp.array(
    [
        [-0.31, -2.03],
        [2.98, 0.08],
        [1.03, 1.29],
    ]
)
w_3 = jnp.array(
    [
        [-0.35, -1.98],
        [3.09, 0.07],
        [1.3, 0.96],
    ]
)

# Collected in sampling order for the Monte Carlo average
list_of_weights = [w_1, w_2, w_3]

def phi(x):
    """
    Feature map of the model, phi(x) = [1, x]^T (bias term plus the raw input).

    Works for a single scalar input as well as for an array of inputs, in
    which case the features are placed on a new trailing axis.

    :param x: Scalar input or array of inputs of any shape ``S``.
    :return: Array of shape ``S + (2,)`` where ``[..., 0]`` is 1 and
             ``[..., 1]`` is ``x``; a scalar input gives shape ``(2,)``.
    """
    x = jnp.asarray(x)
    # Stacking on the last axis turns a batch of N inputs into an (N, 2)
    # design matrix while a scalar still yields the plain vector [1, x].
    return jnp.stack([jnp.ones_like(x), x], axis=-1)


# Monte Carlo estimate of the posterior predictive: average the per-sample
# softmax probabilities over the sampled weight matrices.
probs = []

# Iterate over the samples themselves so the estimate always uses every
# entry of list_of_weights instead of a hard-coded count.
for w_sample in list_of_weights:
    # Logits for x_star under this posterior sample (not under w_MAP)
    pred = w_sample @ phi(x_star)
    probs.append(softmax(pred))

print(f"Probs: {np.mean(probs, axis=0)}")

Probs: [0.2229536  0.72390676 0.05313967]


The predictive distribution $ p(y^* | \mathcal{D}, x^* = 3) $ is given in the table below:

| $ k $ | $ p(y^* = k \mid x^*) $ |
|--------|---------------------------|
| 1      | 0.00                      |
| 2      | 0.27                      |
| 3      | 0.73                      |

---

### Question 1.4: Determine the entropy and confidence of the posterior predictive distribution for $ x^* = 3 $ given in the table above.

---


In [49]:
def confidence(p):
    """
    Computes the confidence for each predictive distribution.

    The confidence is defined as the maximum predicted probability for each sample:
        confidence(x^*) = max_k p(y^*=k | x^*, D)
    where D is the training data.

    Parameters
    ----------
    p : jax.numpy.ndarray
        Posterior predictive probabilities with the classes on the last axis.
        Shape: (N, K), where N is the number of prediction points and K is the
        number of classes, or (K,) for a single predictive distribution.

    Returns
    -------
    conf : jax.numpy.ndarray
        Confidence for each prediction point.
        Shape: (N,) for (N, K) input, scalar for (K,) input.

    Equation
    --------
    conf_n = max_k p_{n,k}
    """
    # Reduce over the class axis (the last one) so that a single 1-D
    # distribution is accepted as well; identical to axis=1 for (N, K) input.
    return jnp.max(p, axis=-1)

def entropy(p):
    """
    Computes the predictive entropy for each predictive distribution.

    The entropy measures the uncertainty of the predictive distribution:
        entropy(x^*) = -sum_k p(y^*=k | x^*, D) * log(p(y^*=k | x^*, D))
    Terms with p = 0 contribute 0 (the limit of p * log(p) as p -> 0).

    Parameters
    ----------
    p : jax.numpy.ndarray
        Posterior predictive probabilities with the classes on the last axis.
        Shape: (N, K), where N is the number of prediction points and K is the
        number of classes, or (K,) for a single predictive distribution.

    Returns
    -------
    ent : jax.numpy.ndarray
        Predictive entropy (in nats) for each prediction point.
        Shape: (N,) for (N, K) input, scalar for (K,) input.

    Equation
    --------
    ent_n = -sum_k p_{n,k} * log(p_{n,k})
    """
    p = jnp.asarray(p)
    positive = p > 0
    # "Double where": feed log() a harmless 1.0 wherever p <= 0. Masking only
    # the product still evaluates log(0) in the untaken branch, which gives
    # the right forward value but NaN gradients under jax.grad.
    safe_p = jnp.where(positive, p, 1.0)
    plogp = jnp.where(positive, p * jnp.log(safe_p), 0.0)
    # Sum over the class axis (the last one); same as axis=1 for (N, K) input.
    return -jnp.sum(plogp, axis=-1)


In [50]:
# Predictive distribution for x* = 3 from the table, stored as a single row
# (N = 1 prediction point, K = 3 classes).
p_hat = jnp.array([[0.0, 0.27, 0.73]])
print(f"p_hat: {p_hat}")
print(f"p_hat shape: {p_hat.shape}")

# Each metric is evaluated once, directly inside the print; a bare call that
# is not the last expression of the cell would be computed and discarded.
print(f"Confidence: {confidence(p_hat)}")
print(f"Entropy: {entropy(p_hat)}")

p_hat: [[0.   0.27 0.73]]
p_hat shape: (1, 3)
Confidence: [0.73]
Entropy: [0.58325887]


---

### Question 1.5: Suppose the value of the hyperparameter $ \alpha $ is increased by a factor of 10. Explain in your own words how you would expect the MAP-estimate to change and why.