In [1]:
# IPython autoreload extension in mode 2: imported modules are re-imported
# before each cell executes, so edits to project code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

# Deep Gaussian Markov Random Fields on graphs (the Graph DGMRF model): the Wikipedia example

The original article is [here](https://proceedings.mlr.press/v162/oskarsson22a/oskarsson22a.pdf).

In [2]:
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]=""

from functools import partial
from torch.distributions import MultivariateNormal
import torch
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import optax
from jax_tqdm import scan_tqdm
import copy
import equinox as eqx
import numpy as np


# Root PRNG key (seed 0); the cells below split fresh subkeys from it for each random draw.
key = jax.random.PRNGKey(0)

In [3]:
# Resolve the CPU device unconditionally so it exists on every machine.
cpu_device = jax.devices('cpu')[0]
try:
    gpu_device = jax.devices('gpu')[0]
    print(cpu_device, gpu_device)
except RuntimeError:
    # jax.devices raises RuntimeError when the requested backend is not
    # available. Fall back to the CPU so that later cells referencing
    # `gpu_device` still run on CPU-only machines instead of hitting a
    # NameError; any other exception type is no longer silently swallowed.
    gpu_device = cpu_device
    print(cpu_device)

TFRT_CPU_0 cuda:0


## Load the wikipedia graph data

In [21]:
# Edge list: one pair of integer node indices per row (the CSV header row is skipped).
edges_mat = np.genfromtxt(
    "./chameleon_edges.csv", delimiter=",", skip_header=1
).astype(np.int32)

# Targets: second CSV column on a log scale. The 1e-6 offset keeps the
# logarithm finite when a raw target is exactly zero.
y = np.log(
    np.genfromtxt("./chameleon_target.csv", delimiter=",", skip_header=1)[:, 1] + 1e-6
)

# Number of nodes = number of target rows.
N = len(y)

Convert the edge list to an adjacency matrix

In [22]:
def update_adjacency_matrix(adjacency_matrix, edge):
    """Mark one undirected edge in the adjacency matrix.

    The signature and the ``(carry, None)`` return value follow the
    ``jax.lax.scan`` step-function convention.

    Parameters
    ----------
    adjacency_matrix : jax.Array
        Current square adjacency matrix (the scan carry).
    edge : sequence of two indices
        The ``(i, j)`` node pair to connect.

    Returns
    -------
    tuple
        ``(updated_matrix, None)`` where both ``[i, j]`` and ``[j, i]`` are set to 1.
    """
    src, dst = edge
    # Functional (out-of-place) updates: write both orientations so the matrix stays symmetric.
    symmetric = adjacency_matrix.at[src, dst].set(1).at[dst, src].set(1)
    return symmetric, None

def edge_list_to_adjacency_matrix(edge_list, num_nodes):
    """Build a dense, symmetric 0/1 adjacency matrix from an edge list.

    Parameters
    ----------
    edge_list : array-like of shape (num_edges, 2)
        Integer node-index pairs ``(i, j)``. May be empty.
    num_nodes : int
        Number of nodes; the result has shape ``(num_nodes, num_nodes)``.

    Returns
    -------
    jax.Array
        ``int8`` matrix with ``A[i, j] == A[j, i] == 1`` for every listed
        edge and 0 everywhere else.
    """
    edge_array = jnp.asarray(edge_list, dtype=jnp.int32)
    adjacency_matrix = jnp.zeros((num_nodes, num_nodes), dtype=jnp.int8)

    # An empty edge list has no (i, j) columns to index: return the empty graph.
    if edge_array.size == 0:
        return adjacency_matrix

    rows, cols = edge_array[:, 0], edge_array[:, 1]
    # One vectorised scatter writes both orientations of every edge at once,
    # instead of threading the dense matrix through a sequential scan with
    # one step per edge.
    adjacency_matrix = adjacency_matrix.at[
        jnp.concatenate([rows, cols]), jnp.concatenate([cols, rows])
    ].set(1)

    return adjacency_matrix

In [23]:
# Assemble the dense adjacency matrix with the CPU as default device,
# then transfer the finished array to the GPU device.
with jax.default_device(cpu_device):
    A = edge_list_to_adjacency_matrix(edges_mat, len(y))
A = jax.device_put(A, device=gpu_device)

Remove the isolated nodes (**TODO**: no cell in this notebook performs this step yet, so isolated nodes, if any, are kept)

Compute the diagonal of the degree matrix

In [24]:
# Row sums of the adjacency matrix give the node degrees, i.e. the diagonal
# of the degree matrix; stored as int32.
D = A.sum(axis=1).astype(jnp.int32)

As stated in the article, 50% of nodes will be unobserved, randomly and uniformly chosen

In [25]:
# Draw half of the node indices uniformly without replacement.
key, subkey = jax.random.split(key)
idx_unobserved = jax.random.choice(
    subkey, jnp.arange(len(y)), shape=(len(y) // 2,), replace=False
)

# mask == 1 flags a node as unobserved, mask == 0 as observed.
mask = jnp.zeros_like(y).at[idx_unobserved].set(1)

# Zero out the targets of the unobserved nodes.
# NOTE(review): `y` is overwritten here, so the original targets of the
# unobserved nodes are no longer available to later cells.
y = jnp.where(mask == 1, 0, y)

Optionally add noise

In [26]:
# Standard deviation of the Gaussian noise added to the targets.
true_sigma_noise = 0.01
key, subkey = jax.random.split(key)
# The noise is added to every entry of `y`, including those zeroed by the mask.
y = y + true_sigma_noise * jax.random.normal(subkey, shape=y.shape)

## Define the DGMRF components

In [27]:
from dgmrf.models import DGMRF
from dgmrf.utils import get_adjacency_matrix_lattice

**Note:** when $L>1$, we get an unwanted smoothing effect for the Graph DGMRF on images.

In [28]:
# Passed as the second positional argument to DGMRF below — presumably the
# number of layers (the $L$ of the note above); confirm against dgmrf.models.
L = 3
# Passed to train_loop below — presumably the number of samples drawn from the
# variational distribution per ELBO estimate; confirm against dgmrf.train.
Nq = 10

In [29]:
# Collapse both arrays to 1-D vectors.
y = jnp.ravel(y)
mask = jnp.ravel(mask)

In [30]:
key, subkey = jax.random.split(key)

# Graph DGMRF built from the adjacency matrix and the degree vector.
dgmrf = DGMRF(subkey, L, A_D=(A, D), log_det_method="eigenvalues")

# Split the model into its inexact-array leaves (the trainable parameters) and
# everything else (kept static). Per the original author's note, a broader
# array filter would also treat the key stored on the layers (`self.key`) as a
# parameter, hence `is_inexact_array`.
dgmrf_params, dgmrf_static = eqx.partition(dgmrf, eqx.is_inexact_array)

## Variational inference to compute model parameters

The initial values of the mean and std for the variational distribution are given at lines 572 and 573 of the original code.

In [31]:
# Three-way split; `subkey1` is not consumed by this cell.
key, subkey1, subkey2 = jax.random.split(key, 3)

# Initial values of everything handed to the optimiser.
params_init = dict(
    dgmrf=dgmrf_params,
    # Noise log-std starts at the value used to generate the noise.
    log_sigma=jnp.log(true_sigma_noise),
    # Variational mean starts at the (masked, noisy) observations.
    nu_phi=y,
    # One entry per node: softplus of small Gaussian draws (std 0.01).
    log_S_phi=jax.nn.softplus(0.01 * jax.random.normal(subkey2, (N,))),
)

# Non-trainable part of the model, passed alongside the parameters.
static = dict(dgmrf=dgmrf_static)


In [32]:
# Number of optimisation steps and learning rate for the graph model.
n_iter = 1000
lr_graph = 1e-3

lr = lr_graph
# One optimiser per top-level entry of the parameter dict.
tx = optax.multi_transform(
    {
        "dgmrf": optax.adam(lr),
        # The noise level is kept fixed. `set_to_zero` freezes it explicitly:
        # unlike Adam with a zero learning rate, it keeps no moment state and
        # cannot turn a NaN/inf gradient into a NaN update.
        "log_sigma": optax.set_to_zero(),
        "nu_phi": optax.adam(lr),
        "log_S_phi": optax.adam(lr),
    },
    # Labels given directly as a dict mirroring the top-level keys of the
    # params dict (this is what a label_fn would otherwise return).
    {
        "dgmrf": "dgmrf",
        "log_sigma": "log_sigma",
        "nu_phi": "nu_phi",
        "log_S_phi": "log_S_phi",
    },
)
opt_state = tx.init(params_init)

In [33]:
# Reassemble a full model object from the initial trainable leaves and the
# static remainder (the two halves produced by eqx.partition above).
dgmrf = eqx.combine(params_init["dgmrf"], static["dgmrf"])

In [34]:
from dgmrf.losses import dgmrf_elbo
from dgmrf.train import train_loop

key, subkey = jax.random.split(key)

# Optimise the parameters with the DGMRF ELBO as loss; the mask is forwarded
# to the training loop as a keyword argument.
params_final, loss_val = train_loop(
    dgmrf_elbo,
    y,
    n_iter,
    params_init,
    static,
    tx,
    opt_state,
    subkey,
    N,
    Nq,
    mask=mask,
)

  0%|          | 0/1000 [00:00<?, ?it/s]

The means of the variational distribution will be the inferred values

In [38]:
# Root-mean-square difference between the variational means and `y`.
# NOTE(review): at this point `y` is the masked and noised array (unobserved
# entries were set to 0 before the noise was added), so this is not an error
# against the original targets of the unobserved nodes.
rmse_nu_phi = jnp.sqrt(jnp.mean(jnp.square(params_final["nu_phi"] - y)))
print("RMSE with the means of the variational distribution", rmse_nu_phi)

RMSE with the means of the variational distribution 0.17576729


## Posterior mean computation with conjugate gradient
**TODO**

In [39]:
# Rebuild the model from the trained leaves and the static remainder so that
# its methods use the fitted parameters.
dgmrf = eqx.combine(params_final["dgmrf"], static["dgmrf"])

In [46]:
# Posterior mean from the trained model, requested with method="cg".
# NOTE(review): the recorded output of this cell is all-NaN; the cause cannot
# be determined from this notebook alone and needs investigation.
xpost_mean_cg = dgmrf.get_post_mu(y, params_final["log_sigma"], params_final["nu_phi"], mask=mask, method="cg")
xpost_mean_cg

Array([nan, nan, nan, ..., nan, nan, nan], dtype=float32)

In [45]:
# Root-mean-square difference between the conjugate-gradient posterior mean and `y`.
# NOTE(review): `y` is the masked and noised array here, and the result is NaN
# whenever `xpost_mean_cg` contains NaN.
rmse_cg = jnp.sqrt(jnp.mean(jnp.square(xpost_mean_cg - y)))
print("RMSE with the posterior mean from conjugate gradient", rmse_cg)

RMSE with the posterior mean from conjugate gradient nan


## Visualizations
**TODO**