$$C(\theta) = \frac{1}{2}\theta^T A\theta - b^T\theta$$

If $A$ is symmetric positive definite, the unique minimum is at $\theta^* = A^{-1}b$.


Gradient: $\nabla C = A\theta - b$


Hessian: $H(\theta) = A$
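The closed-form gradient, Hessian, and minimizer above can be checked numerically with JAX autodiff. This is a minimal sketch: the 2x2 matrix `A`, the vector `b`, and the test point `theta` are made-up values chosen only so that `A` is symmetric positive definite.

```python
import jax
import jax.numpy as jnp

# Hypothetical symmetric positive-definite A and vector b (illustrative values only)
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, -1.0])

def C(theta):
    # Quadratic cost C(theta) = 1/2 theta^T A theta - b^T theta
    return 0.5 * theta @ A @ theta - b @ theta

theta = jnp.array([0.5, -0.3])          # arbitrary test point
theta_star = jnp.linalg.solve(A, b)     # A^{-1} b without forming the inverse

grad_C = jax.grad(C)(theta)             # should match A theta - b
hess_C = jax.hessian(C)(theta)          # should match A
grad_at_min = jax.grad(C)(theta_star)   # should vanish at the minimizer
```

Using `jnp.linalg.solve` rather than an explicit inverse is the usual choice for numerical stability.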

Fisher information matrix (FIM): $F = \mathbb{E}_{x \sim p(x | \theta)}\left[(\nabla_\theta \log p(x | \theta))(\nabla_\theta \log p(x | \theta))^\top\right].$


Note: the Fisher information matrix $F$ is the Hessian of the KL divergence $\mathrm{KL}\left(p(x | \theta) \,\|\, p(x | \theta')\right)$ with respect to $\theta'$, evaluated at $\theta' = \theta$.
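The link between the FIM and the KL Hessian can be checked on a small case. The sketch below assumes a Bernoulli distribution with a single logit parameter, so that $p(x = 1 | \theta) = \sigma(\theta)$ and both quantities should equal $\sigma(\theta)(1 - \sigma(\theta))$. The function names `probs`, `log_probs`, `kl`, and `fisher` are illustrative, not from the notebook.

```python
import jax
import jax.numpy as jnp

def probs(theta):
    # [p(x=0 | theta), p(x=1 | theta)] for a Bernoulli with logit theta
    p = jax.nn.sigmoid(theta)
    return jnp.stack([1.0 - p, p])

def log_probs(theta):
    # log p(x | theta) for x = 0 and x = 1
    return jnp.stack([jax.nn.log_sigmoid(-theta), jax.nn.log_sigmoid(theta)])

def kl(theta, theta_prime):
    # KL(p(x | theta) || p(x | theta_prime)), summed exactly over x in {0, 1}
    return jnp.sum(probs(theta) * (log_probs(theta) - log_probs(theta_prime)))

def fisher(theta):
    # Exact expectation over x ~ p(x | theta) of the squared score
    scores = jax.jacobian(log_probs)(theta)
    return jnp.sum(probs(theta) * scores ** 2)

theta = jnp.array(0.7)
# Hessian with respect to the second argument, evaluated at theta_prime = theta
kl_hessian = jax.hessian(kl, argnums=1)(theta, theta)
fisher_value = fisher(theta)
```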

----------------

**GD**:    $x_{new} = x_{old} - \alpha \nabla f(x_{old})$


**Newton's method**:    $x_{new} = x_{old} - \alpha H^{-1}(x_{old}) \nabla f(x_{old})$

**Natural Gradient**: $x_{new} = x_{old} - \alpha F^{-1}(x_{old}) \nabla f(x_{old})$
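To make the difference between the updates concrete, here is a small sketch comparing gradient descent and Newton's method on a quadratic objective. The matrix `A`, vector `b`, starting point `x0`, and step sizes are assumptions picked for illustration. On a quadratic the Hessian is constant, so a single Newton step with $\alpha = 1$ lands on the minimizer, while gradient descent needs many steps. The natural gradient update has the same form with $F$ in place of $H$.

```python
import jax
import jax.numpy as jnp

# Hypothetical quadratic f(x) = 1/2 x^T A x - b^T x (illustrative values only)
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, -1.0])

def f(x):
    return 0.5 * x @ A @ x - b @ x

def gd_step(x, alpha):
    # x_new = x_old - alpha * grad f(x_old)
    return x - alpha * jax.grad(f)(x)

def newton_step(x, alpha):
    # x_new = x_old - alpha * H^{-1} grad f(x_old), solved as a linear system
    H = jax.hessian(f)(x)
    return x - alpha * jnp.linalg.solve(H, jax.grad(f)(x))

x0 = jnp.array([5.0, 5.0])
x_star = jnp.linalg.solve(A, b)      # exact minimizer

x_newton = newton_step(x0, alpha=1.0)  # one Newton step

x_gd_one = gd_step(x0, alpha=0.1)      # one gradient descent step
x_gd = x0
for _ in range(100):                   # 100 gradient descent steps
    x_gd = gd_step(x_gd, alpha=0.1)
```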

In [7]:
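import jax.numpy as jnp
from jax import grad, jit, random, vmap
from jax.scipy.special import expit as sigmoid

# Data simulation (for the sake of the example)
key = random.PRNGKey(0)
# Use separate subkeys so the labels are not drawn from the same random stream as the features
key_x, key_y = random.split(key)
num_samples = 1000
dim = 5
X = random.normal(key_x, (num_samples, dim))
true_w = jnp.array([1.5, -2.5, 2.0, 1.0, -1.0])
true_b = -0.5
logits = jnp.dot(X, true_w) + true_b
probs = sigmoid(logits)
# Cast the boolean draws to float so they can be used directly in the likelihood arithmetic
y = random.bernoulli(key_y, probs).astype(jnp.float32)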

# Model and log likelihood
def model(params, x):
    w, b = params
    return sigmoid(jnp.dot(x, w) + b)

def neg_log_likelihood(params, x, y):
    p = model(params, x)
    return -jnp.mean(y * jnp.log(p) + (1 - y) * jnp.log(1 - p))

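# Gradient of the negative log likelihood with respect to params
# (the sign does not affect the FIM, since g g^T is unchanged under g -> -g)
grad_log_lik = grad(neg_log_likelihood)

@jit
def compute_fisher_information(params, x, y):
    # Per-sample gradients. params is a (w, b) tuple, so this returns a tuple
    # of arrays with shapes (num_samples, dim) and (num_samples,)
    grads_w, grads_b = vmap(grad_log_lik, (None, 0, 0))(params, x, y)

    # Stack into one (num_samples, dim + 1) matrix; einsum needs arrays, not a tuple
    grads = jnp.concatenate([grads_w, grads_b[:, None]], axis=1)

    # Average of per-sample outer products. The observed labels y stand in for
    # samples from the model here, so this is the empirical Fisher.
    fim = jnp.mean(jnp.einsum('ij,ik->ijk', grads, grads), axis=0)

    return fim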

# Initialize parameters
params = (jnp.zeros(dim), jnp.array(0.0))

fim = compute_fisher_information(params, X, y)
print(f"Fisher Information Matrix: \n{fim}")


Note: passing the output of `vmap(grad_log_lik, ...)` straight to `jnp.einsum` raises `TracerArrayConversionError` under `jit` ("The numpy.ndarray conversion method __array__() was called on traced array with shape float32[1000,5]"). Because `params` is a `(w, b)` tuple, the per-sample gradients come back as a tuple of two arrays with shapes `[1000, 5]` and `[1000]`, and `einsum` tries to convert that tuple into a single array. Concatenating the per-sample gradients into one `[1000, 6]` matrix before taking the outer products avoids the error.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
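As a follow-up, the natural gradient update can be sketched for logistic regression with the FIM computed as an exact expectation over labels drawn from the model, rather than from the observed labels. Everything here is an assumption for illustration: the data sizes, `true_theta`, the flat parameter vector with the bias folded in as a constant feature, the `damping` term, and the helper names `fisher` and `natural_gradient_step`. For this model the exact FIM coincides with the Hessian of the mean negative log likelihood, so the update behaves like a damped Newton step.

```python
import jax
import jax.numpy as jnp
from jax import random

# Hypothetical data; a column of ones folds the bias into theta
key_x, key_y = random.split(random.PRNGKey(0))
X = random.normal(key_x, (200, 3))
X1 = jnp.concatenate([X, jnp.ones((200, 1))], axis=1)
true_theta = jnp.array([1.5, -2.0, 0.5, -0.5])
y = random.bernoulli(key_y, jax.nn.sigmoid(X1 @ true_theta)).astype(jnp.float32)

def log_lik(theta, x, y):
    # Per-sample Bernoulli log likelihood with logit z = x . theta
    z = x @ theta
    return y * jax.nn.log_sigmoid(z) + (1.0 - y) * jax.nn.log_sigmoid(-z)

def nll(theta):
    return -jnp.mean(jax.vmap(log_lik, (None, 0, 0))(theta, X1, y))

def fisher(theta):
    # Expectation over y ~ p(y | x, theta), done exactly by summing over y in {0, 1}
    score = jax.vmap(jax.grad(log_lik), (None, 0, None))
    p1 = jax.nn.sigmoid(X1 @ theta)
    outer = lambda s: jnp.einsum('ij,ik->ijk', s, s)
    s1, s0 = score(theta, X1, 1.0), score(theta, X1, 0.0)
    per_sample = p1[:, None, None] * outer(s1) + (1.0 - p1)[:, None, None] * outer(s0)
    return jnp.mean(per_sample, axis=0)

def natural_gradient_step(theta, alpha=0.5, damping=1e-4):
    # theta_new = theta_old - alpha * F^{-1} grad; damping keeps F well conditioned
    F = fisher(theta) + damping * jnp.eye(theta.shape[0])
    return theta - alpha * jnp.linalg.solve(F, jax.grad(nll)(theta))

theta = jnp.zeros(4)
for _ in range(20):
    theta = natural_gradient_step(theta)
```

With the observed labels used instead of the model expectation, the same outer-product average gives the empirical Fisher, which generally differs from the Hessian.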