# framework:
    1. we have n_spins random variables, whose range is the real line
    2. the real distribution is a multivariate gaussian p_true ~ N(true_mean, true_cov)
    3. we see n_data samples
    4. we use score matching on the family of gaussians to infer true_mean, true_cov

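NB: this is a toy model. For Gaussians, score matching is equivalent to maximum likelihood estimation (MLE), so the optimum could be computed directly from the sample mean and covariance; we use gradient descent here only to exercise the training procedure.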

In [1]:
import jax
import jax.numpy as jnp
import jax.random as jrand
from jax import grad, jit, vmap
import matplotlib.pyplot as plt

### data creation

In [2]:
# fixed seed so the random draws are reproducible; `key` is the root PRNG key
# that the following cells split before each random draw
seed = 1
key = jrand.PRNGKey(seed)

In [3]:
# n_spins: number of random variables (dimension of the gaussian)
# n_data:  number of samples drawn from the true distribution
n_spins = 2
n_data = 50

In [4]:
# ground-truth mean: one uniform draw in [0, 1) per variable.
# `shape` is given as a tuple (the documented form); a bare int is only
# accepted by some JAX versions.
key, subkey = jrand.split(key)
true_mean = jrand.uniform(subkey, shape=(n_spins,))

In [5]:
# ground-truth covariance: for any square matrix A the product A @ A.T is
# symmetric positive semi-definite, so draw A with uniform entries and use that
key, subkey = jrand.split(key)
temp = jrand.uniform(subkey, shape=(n_spins, n_spins), minval=0)
true_cov = jnp.matmul(temp, temp.T)

In [6]:
# observed samples: n_data draws from N(true_mean, true_cov), one per row.
# `shape` is the batch shape and is given as a tuple (the documented form);
# a bare int is only accepted by some JAX versions.
key, subkey = jrand.split(key)
data = jrand.multivariate_normal(subkey, true_mean, true_cov, shape=(n_data,))

### densities and derivatives

In [7]:
def gaussian_density(x, mean, cov, precision):
    """Evaluate the multivariate normal density at the point ``x``.

    Args:
        x: point at which the density is evaluated (assumed to be a 1-D array
            of the same length as ``mean``).
        mean: mean vector of the gaussian.
        cov: covariance matrix; only its determinant is used, for the
            normalisation constant.
        precision: matrix used in the quadratic form of the exponent. The
            caller is expected to pass the inverse of ``cov``; this is not
            checked here.

    Returns:
        The density value ``exp(-0.5 * d^T precision d) / sqrt((2 pi)^n det(cov))``
        with ``d = x - mean`` and ``n = len(mean)``.
    """
    dim = len(mean)
    centered = x - mean

    # quadratic form (x - mean)^T precision (x - mean)
    quad_form = jnp.dot(jnp.dot(centered, precision), centered)

    # normalisation constant 1 / sqrt((2 pi)^dim * det(cov))
    normalizer = 1.0 / jnp.sqrt((2 * jnp.pi)**dim * jnp.linalg.det(cov))

    return normalizer * jnp.exp(-0.5 * quad_form)

In [8]:
def grad_log_gaussian(x, mean, precision):
    """Score of a gaussian: gradient with respect to ``x`` of its log-density.

    Args:
        x: point at which the score is evaluated.
        mean: mean vector of the gaussian.
        precision: precision (inverse covariance) matrix.

    Returns:
        ``-precision @ (x - mean)``.
    """
    return jnp.negative(jnp.dot(precision, x - mean))

In [9]:
def lapl_log_gaussian(precision):
    """Laplacian with respect to ``x`` of a gaussian log-density.

    It does not depend on ``x``: it is minus the trace of the precision
    (inverse covariance) matrix.
    """
    trace_of_precision = jnp.trace(precision)
    return -trace_of_precision

### score matching

remember we assume gaussianity...

- the score is: $\nabla_x \log p_\theta(x) = -\Sigma^{-1}(x - \mu)$  
- the laplacian is: $\Delta_x \log p_\theta(x) = -\operatorname{Tr}(\Sigma^{-1})$

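plugging into the score matching loss $\mathcal{L}(\theta) = \mathbb{E}_{x \sim \text{data}} \left[ \frac{1}{2} \| \nabla_x \log p_\theta(x) \|^2 + \Delta_x \log p_\theta(x) \right]$ gives: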

$$
\mathcal{L}(\theta) = \mathbb{E}_{x \sim \text{data}} \left[ \frac{1}{2} \| \Sigma^{-1}(x - \mu) \|^2 - \operatorname{Tr}(\Sigma^{-1}) \right]
$$

---

### parametrization

to ensure that the covariance matrix ($\Sigma$) remains positive definite, we parametrize it using its Cholesky decomposition:

- the diagonal entries of the Cholesky factor ($L$) are exponentiated from `log_sigma_diag`  
- the lower-triangular off-diagonal entries are stored in `cholesky_factor_offdiag`  
- the full matrix is reconstructed as $\Sigma = L L^\top$

In [10]:
def score_matching_loss(params_raw, data):
    """Empirical score matching loss of a gaussian model on ``data``.

    The covariance is parametrised through its Cholesky factor ``L``
    (``cov = L @ L.T``): the diagonal of ``L`` is ``exp(log_sigma_diag)`` so it
    stays strictly positive, and the strictly-lower-triangular entries are
    free parameters.

    The returned value is
    ``mean_i[0.5 * ||score(x_i)||^2] + laplacian``,
    with the score given by ``grad_log_gaussian`` and the (x-independent)
    laplacian by ``lapl_log_gaussian``.

    Args:
        params_raw: dict with keys ``'mean'`` (length ``n_dim``),
            ``'log_sigma_diag'`` (length ``n_dim``) and
            ``'cholesky_factor_offdiag'`` (length ``n_dim * (n_dim - 1) // 2``,
            written into ``L`` in the order of ``jnp.tril_indices(n_dim, k=-1)``).
        data: array of shape ``(N, n_dim)``, one sample per row.

    Returns:
        Scalar loss averaged over the ``N`` samples.

    Raises:
        ValueError: if ``data`` contains no samples.
        AssertionError: if ``cholesky_factor_offdiag`` has the wrong length
            (only checked when ``n_dim > 1``).
    """
    mean = params_raw['mean']
    log_sigma_diag = params_raw['log_sigma_diag']
    cholesky_factor_offdiag = params_raw['cholesky_factor_offdiag']

    n_dim = len(mean)

    # an empty batch has no well-defined average; fail loudly instead of
    # returning nan
    if data.shape[0] == 0:
        raise ValueError("score_matching_loss needs at least one data point")

    # assemble the lower-triangular Cholesky factor L
    L = jnp.zeros((n_dim, n_dim), dtype=data.dtype)
    L = L.at[jnp.diag_indices(n_dim)].set(jnp.exp(log_sigma_diag))
    if n_dim > 1:
        assert cholesky_factor_offdiag.shape[0] == (n_dim * (n_dim - 1) // 2)
        L = L.at[jnp.tril_indices(n_dim, k=-1)].set(cholesky_factor_offdiag)

    # covariance and its inverse (the precision matrix)
    cov = L @ L.T
    precision = jnp.linalg.inv(cov)

    # scores of all samples at once: vmap maps over the rows of `data` and
    # broadcasts `mean` / `precision`; result has shape (N, n_dim)
    scores = vmap(grad_log_gaussian, in_axes=(0, None, None))(data, mean, precision)
    mean_half_sq_norm = 0.5 * jnp.mean(jnp.sum(scores**2, axis=1))

    # the laplacian does not depend on x, so it is added once rather than
    # once per sample
    return mean_half_sq_norm + lapl_log_gaussian(precision)


In [11]:
# number of free strictly-lower-triangular entries of the Cholesky factor
num_off_diag_elements = n_spins * (n_spins - 1) // 2

# initial guess: zero mean and identity Cholesky factor (exp(0) = 1 on the
# diagonal, 0 below it), i.e. a standard normal
params_raw = {
    'mean': jnp.zeros(n_spins),
    'log_sigma_diag': jnp.zeros(n_spins),
    'cholesky_factor_offdiag': jnp.zeros(num_off_diag_elements)
}

# plain gradient descent hyper-parameters
learning_rate = 0.001
num_iterations = 4000

# loss and gradient w.r.t. params_raw in one call; jit-compiled so the loss is
# traced once instead of being re-executed eagerly on every iteration
loss_and_grad_fn = jit(jax.value_and_grad(score_matching_loss))

In [12]:
def _covariance_from_params(params):
    """Rebuild the covariance ``L @ L.T`` from the raw (unconstrained) parameters."""
    n_dim = params['mean'].shape[0]
    chol = jnp.zeros((n_dim, n_dim))
    chol = chol.at[jnp.diag_indices(n_dim)].set(jnp.exp(params['log_sigma_diag']))
    # strictly-lower-triangular entries only exist for dimension > 1
    if n_dim > 1:
        chol = chol.at[jnp.tril_indices(n_dim, k=-1)].set(params['cholesky_factor_offdiag'])
    return chol @ chol.T


# how often (in iterations) the current estimates are printed
print_every = 500

# loss value at every iteration
loss_history = []

print("starting the optimization \n\n\n")

# plain gradient descent on the raw parameters
for i in range(num_iterations):
    loss_val, grads = loss_and_grad_fn(params_raw, data)
    loss_history.append(loss_val.item())

    # same update rule applied to every leaf of the parameter dict
    params_raw = jax.tree.map(lambda p, g: p - learning_rate * g, params_raw, grads)

    if (i + 1) % print_every == 0:
        estimated_mean = params_raw['mean']
        estimated_cov = _covariance_from_params(params_raw)

        print(f"iteration {i+1}/{num_iterations}\n")
        print(f"  estimated Mean: {estimated_mean}\n")
        print(f"  estimated Covariance:\n{estimated_cov}\n\n\n")

# final estimates, taken from the parameters after the last update so that the
# cells below see them even when num_iterations is not a multiple of print_every
estimated_mean = params_raw['mean']
estimated_cov = _covariance_from_params(params_raw)

starting the optimization 



iteration 500/4000

  estimated Mean: [0.05451116 0.4317398 ]

  estimated Covariance:
[[0.9454663  0.33368996]
 [0.33368996 0.19797575]]



iteration 1000/4000

  estimated Mean: [0.1948535  0.48528156]

  estimated Covariance:
[[0.90306747 0.31750458]
 [0.31750458 0.19178313]]



iteration 1500/4000

  estimated Mean: [0.2936692 0.5229742]

  estimated Covariance:
[[0.84912366 0.29692477]
 [0.29692477 0.18393223]]



iteration 2000/4000

  estimated Mean: [0.3574913 0.5473177]

  estimated Covariance:
[[0.82127243 0.28630427]
 [0.28630427 0.17988238]]



iteration 2500/4000

  estimated Mean: [0.3951242  0.56167203]

  estimated Covariance:
[[0.81047714 0.28218898]
 [0.28218898 0.17831337]]



iteration 3000/4000

  estimated Mean: [0.4163038 0.5697507]

  estimated Covariance:
[[0.8068265  0.28079757]
 [0.28079757 0.17778319]]



iteration 3500/4000

  estimated Mean: [0.42801183 0.57421666]

  estimated Covariance:
[[0.80566853 0.28035626]
 [0.28035626

In [15]:
# ground-truth parameters the samples were drawn from, shown for comparison
true_mean, true_cov

(Array([0.40364635, 0.548895  ], dtype=float32),
 Array([[0.7159164 , 0.21423683],
        [0.21423683, 0.12737265]], dtype=float32))

In [13]:
# signed error of the fitted mean (estimate minus ground truth)
estimated_mean - true_mean

Array([0.03079867, 0.02777547], dtype=float32)

In [14]:
# signed, entry-wise error of the fitted covariance (estimate minus ground truth)
estimated_cov - true_cov

Array([[0.0893954 , 0.0659835 ],
       [0.0659835 , 0.05019057]], dtype=float32)