In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
import matplotlib.patches as patches

import core
import os

import seaborn as sns
# Global seaborn look for every figure in this notebook: large "talk" fonts
# on a plain white background.
sns.set_theme(context="talk", style="white")

# Fill colours for the unsafe / safe regions, taken from entries 3 and 2 of
# seaborn's "pastel" palette.
_pastel = sns.color_palette("pastel")
safe_color = _pastel[2]
unsafe_color = _pastel[3]

In [None]:
# Load the trained network and put it in evaluation mode.
# map_location='cpu' lets a checkpoint that was saved on a GPU load on a
# CPU-only machine; the outputs are converted with .numpy() below, which
# requires CPU tensors anyway.
# NOTE(review): the constructor arguments are presumably
# (state dim, hidden width, control dim) -- confirm against core.CLF_QP_Net.
net = core.CLF_QP_Net(4, 256, 2)
net.load_state_dict(torch.load('./logs/model.pth', map_location='cpu'))
net.eval()

res = 50  # grid resolution: number of sample points along each position axis
num_iter = 30  # number of random velocity draws averaged at every grid point

# Seeded generator so the averaged surfaces (and the saved figure) are
# reproducible after a kernel restart.
rng = np.random.default_rng(0)

# Square position grid. x and y are kept as (res, res) arrays because the
# plotting code here and in the following cell uses them in that shape; the
# network input uses separate column-vector views instead of overwriting them.
axis_pts = np.linspace(-12, 12, res)
x, y = np.meshgrid(axis_pts, axis_pts)
x_flat = x.reshape(-1, 1)
y_flat = y.reshape(-1, 1)

# Running sums of the network's V and Vdot outputs over the velocity draws.
z = np.zeros(shape=(res, res))
vdot = np.zeros(shape=(res, res))
for _draw in tqdm(range(num_iter)):
    # Fresh uniform velocities in [-5, 5) for every grid point.
    vx = rng.uniform(-5, 5, size=x_flat.shape)
    vy = rng.uniform(-5, 5, size=y_flat.shape)

    # One input row per grid point: [x, y, vx, vy].
    s = np.concatenate([x_flat, y_flat, vx, vy], axis=1).astype(np.float32)
    s = torch.from_numpy(s)
    # NOTE(review): deliberately not wrapped in torch.no_grad() -- the net may
    # rely on autograd internally to produce Vdot; confirm before adding it.
    _unused, V, Vdot = net(s)
    z += V.detach().numpy().reshape(res, res)
    vdot += Vdot.detach().numpy().reshape(res, res)

# Turn the sums into means over the velocity draws.
z /= num_iter
vdot /= num_iter

contours = plt.contourf(x, y, z, cmap="magma", levels=20)
plt.colorbar(contours, orientation="vertical")
plt.axis('off')

# Saved under its own name: the two-panel figure in the next cell writes
# 'contour.png' and would otherwise silently overwrite this image.
plt.savefig('contour_V.png')
plt.show()

100%|██████████████████████████████████████████████████████████████████████████████████| 30/30 [00:21<00:00,  1.38it/s]


In [None]:
# Weight on V in the quantity shown in the right-hand panel.
# NOTE(review): presumably the decay rate lambda of a CLF decrease condition
# dV/dt + lambda * V <= 0 -- confirm against the training code.
CLF_LAMBDA = 0.15


def _add_region_outlines(ax):
    """Draw the safe (green) and unsafe (red) region outlines on ``ax``.

    A matplotlib patch can only be attached to one Axes, so a fresh set of
    patches is built on every call instead of sharing objects between panels.
    """
    outline = dict(linewidth=2, fill=False)
    # NOTE(review): theta2 (-135) is smaller than theta1 (-45); matplotlib
    # wraps theta2 by +360, so each wedge sweeps 270 degrees counter-clockwise
    # from -45 to 225. Confirm this is the intended region.
    safe_sector = patches.Wedge((0, 0), 4.5, -45, -135,
                                edgecolor='g', facecolor=safe_color, **outline)
    unsafe_sector = patches.Wedge((0, 0), 3.5, -45, -135,
                                  edgecolor='r', facecolor=unsafe_color, **outline)
    safe_circle = patches.Circle((0, 0), 8,
                                 edgecolor='g', facecolor=safe_color, **outline)
    # facecolor is not rendered while fill=False; it is set to the unsafe
    # colour so the patch stays correct if filling is ever switched on.
    unsafe_circle = patches.Circle((0, 0), 9,
                                   edgecolor='r', facecolor=unsafe_color, **outline)
    for patch in (safe_sector, unsafe_sector, safe_circle, unsafe_circle):
        ax.add_patch(patch)


fig, axs = plt.subplots(1, 2, figsize=(17, 10.5))

# Left panel: averaged V over the position grid.
contours = axs[0].contourf(x, y, z, cmap="magma", levels=20)
fig.colorbar(contours, ax=axs[0], orientation="horizontal")
_add_region_outlines(axs[0])

# Single-point lines draw nothing visible; they only provide legend handles
# in the same colours as the region outlines.
axs[0].plot([0], [0], color='r', linewidth=2, label='unsafe')
axs[0].plot([0], [0], color='g', linewidth=2, label='safe')

axs[0].set_xlabel('$p_x$')
axs[0].set_ylabel('$p_y$')
axs[0].set_title('$V$')
axs[0].legend()

# Right panel: positive part of dV/dt + CLF_LAMBDA * V (zero wherever the
# expression is non-positive).
violation = np.maximum(0, vdot + CLF_LAMBDA * z)
contours = axs[1].contourf(x, y, violation, cmap="Greys", levels=10)
fig.colorbar(contours, ax=axs[1], orientation="horizontal")
_add_region_outlines(axs[1])

axs[1].set_xlabel('$p_x$')
axs[1].set_ylabel('$p_y$')
# The title states the quantity actually plotted, including the V term.
axs[1].set_title(rf'$max(dV/dt + {CLF_LAMBDA}\,V, 0)$')

fig.savefig('contour.png')
plt.show()