*Note for CPU energy monitoring:* To measure CPU energy with Zeus, run `sudo su`, then launch `jupyter-lab --allow-root`, and modify `time_complete_sampling()` so that `gpu_energy` is replaced by `cpu_energy`. The computation must also run on the CPU, which we do by creating `cpu_device = jax.devices('cpu')[0]` and wrapping the benchmark in `with jax.default_device(cpu_device):`. We cannot simply set `jax_platforms=cpu`, because Zeus complains if the GPU is hidden entirely.

In [None]:
import jax
import numpy as np
from mrfx.experiments import time_complete_sampling, plot_benchmark
from mrfx.samplers import ChromaticGibbsSampler
from mrfx.models import Potts

# Device selection for the benchmark.
# Set USE_CPU = True to run everything on the CPU (e.g. for CPU energy
# monitoring with zeus). We pin JAX's default device rather than setting
# `jax_platforms=cpu`, because hiding the GPU entirely makes zeus complain.
# NOTE(review): measuring CPU energy additionally requires changing
# `gpu_energy` to `cpu_energy` inside `time_complete_sampling()`.
import contextlib

USE_CPU = False

# nullcontext() leaves JAX's default device untouched, so the default
# (USE_CPU = False) behaves exactly like running without any device pinning.
device_ctx = (
    jax.default_device(jax.devices("cpu")[0])
    if USE_CPU
    else contextlib.nullcontext()
)

with device_ctx:
    key = jax.random.PRNGKey(0)

    K = 2
    beta = 1.0
    # NOTE(review): `potts_model` is not used in this cell; presumably kept
    # for interactive inspection in later cells -- confirm before removing.
    potts_model = Potts(K=K, beta=beta, neigh_size=1)

    # Benchmark grid: number of classes K in 2..7 and square lattices from
    # 16x16 up to 512x512, each configuration repeated `reps` times.
    Ks = np.arange(2, 8)
    sizes = [(2**e, 2**e) for e in range(4, 10)]
    reps = 10

    key, subkey = jax.random.split(key, 2)
    times, n_iterations, samples, energy = time_complete_sampling(
        ChromaticGibbsSampler,
        Potts,
        subkey,
        Ks,
        sizes,
        reps,
        kwargs_sampler={
            "eps": 0.05,
            "max_iter": 10000,
            "color_update_type": "vmap_in_color",
        },
        kwargs_model={"beta": beta},
        # Label the experiment with the device actually used, so CPU runs are
        # not recorded under the GPU name.
        exp_name=f"Gibbs_sampler_{'CPU' if USE_CPU else 'GPU'}",
        with_energy=True,
    )

In [None]:
# Plot the wall-clock time of every (K, size) configuration.
# Left as the cell's final expression so the notebook displays the result.
plot_benchmark(
    Ks,
    sizes,
    times,
    title="Times for Chromatic Gibbs sampling",
    ylabel="times (s)",
)

In [None]:
# Plot the measured energy of every (K, size) configuration.
# Left as the cell's final expression so the notebook displays the result.
plot_benchmark(
    Ks,
    sizes,
    energy,
    title="Energy for Chromatic Gibbs sampling",
    ylabel="energy (J)",
)