In [None]:
import einops
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

In [None]:
# Unnormalized action coordinates are denoted `u` (horizontal) and `v`
# (vertical). Both are sampled on [-1, 1] with `grid_size` points, giving a
# regular grid with cell spacing `du` x `dv`.
grid_size = 50
action_u = torch.linspace(-1.0, 1.0, steps=grid_size)
du = action_u[1] - action_u[0]
action_v = torch.linspace(-1.0, 1.0, steps=grid_size)
dv = action_v[1] - action_v[0]
# "xy" indexing: action_u_grid[i, j] == action_u[j] (u varies along columns, W)
# and action_v_grid[i, j] == action_v[i] (v varies along rows, H), which is the
# layout `pcolormesh(X, Y, C)` expects. Passing `indexing` explicitly also
# avoids torch's deprecation warning, and each grid is named after the axis
# whose values it actually holds, so this stays correct if the ranges differ.
action_u_grid, action_v_grid = torch.meshgrid(action_u, action_v, indexing="xy")

# Flatten each (H, W) coordinate grid row-major into a vector of H*W values,
# then pair the two vectors into one batch of (u, v) query points with shape
# (1, H*W, 2) so the density can be scored at every grid point in one call.
action_us, action_vs = (
    einops.rearrange(coord_grid, "H W -> (H W)")
    for coord_grid in (action_u_grid, action_v_grid)
)
ys = einops.rearrange([action_us, action_vs], "C HW -> () HW C")

# Isotropic 2-D Gaussian centred at the origin with a per-axis standard
# deviation of 0.1; the covariance is therefore diag(std**2).
mu = torch.zeros(2)
std = torch.full((2,), 0.1)
cov = std.pow(2).diag()
distr = MultivariateNormal(loc=mu, covariance_matrix=cov)

# Evaluate the density at every query point (shape (1, H*W)), drop the leading
# batch axis, and fold the flat values back into a (grid_size, grid_size)
# image laid out like the coordinate grids.
Zs = distr.log_prob(ys).exp()
Z_grid = einops.rearrange(Zs.squeeze(0), "(H W) -> H W", H=grid_size)
# Heat map of the density over the (u, v) action square. The explicit
# Figure/Axes API ties the colorbar to this mesh instead of to whatever
# pyplot currently considers the active image.
fig, ax = plt.subplots()
mesh = ax.pcolormesh(
    action_u_grid.numpy(),
    action_v_grid.numpy(),
    Z_grid.numpy(),
    cmap="magma",
    shading="auto",
)
fig.colorbar(mesh, ax=ax, label=r"$p_\theta(x)$")
# u and v span the same range, so an equal aspect ratio keeps an isotropic
# Gaussian circular instead of stretching it to the figure's width.
ax.set_aspect("equal")
ax.set_title("Gaussian density over the (u, v) action grid")
ax.grid(False)
ax.axis("off")

# Sanity check: Riemann-sum the density over the grid (each cell has area
# du * dv). The result approximates the probability mass inside [-1, 1]^2.
area = Z_grid.sum().mul(du).mul(dv)
print(area)