In [None]:
import os
import os.path as osp
import numpy as np
import pickle
import matplotlib.pyplot as plt
import numpy.linalg as LA
import torch


class VFunc:
    """Evaluate a scalar value estimate for a state from a saved policy/critic pair.

    The policy proposes an action for the given state, the critic scores the
    (state, action) pair, and the smaller of the critic's first two outputs is
    returned.
    """

    def __init__(
        self,
        policy_path="./saved_model/policy_drone_xz_LBAC_draw_clbf_epi2000_seed1.pkl",
        critic_path="./saved_model/critic_drone_xz_LBAC_draw_clbf_epi2000_seed1.pkl",
        device="cpu",
    ):
        """Load the pickled policy and critic and place them on ``device``.

        Args:
            policy_path: File holding the pickled policy object.
            critic_path: File holding the pickled critic object.
            device: Anything accepted by ``torch.device``; defaults to CPU.
        """
        # Resolve the device *before* loading so it can serve as map_location;
        # without it a checkpoint saved from a GPU cannot be deserialized on a
        # CPU-only machine.
        self.device = torch.device(device)
        # NOTE(security): torch.load unpickles arbitrary Python objects, so only
        # point these paths at trusted files.
        # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which
        # rejects whole-object pickles -- pass weights_only=False there if these
        # files store full modules (confirm against the saving code).
        self.policy = torch.load(policy_path, map_location=self.device)
        self.critic = torch.load(critic_path, map_location=self.device)

    def compute_v(self, state_array):
        """Return the value estimate for a single state.

        Args:
            state_array: 1-D array-like of state components; it is converted to
                a float32 tensor and given a leading batch dimension of 1.

        Returns:
            float: ``min`` of the first two critic outputs for the state and the
            action produced by ``policy.sample``.
        """
        state = torch.as_tensor(
            state_array, dtype=torch.float32, device=self.device
        ).unsqueeze(0)
        # Pure inference: skip autograd bookkeeping for every evaluated point.
        with torch.no_grad():
            # NOTE(review): sample() presumably draws a stochastic action, so
            # repeated calls on the same state may differ -- confirm, and seed
            # torch if reproducible plots are required.
            action, _, _ = self.policy.sample(state)
            v = self.critic(state, action)
        # Take the more pessimistic of the two critic outputs.
        return min(v[0].item(), v[1].item())
    
# Build the evaluator once; the grid-evaluation cell below reuses this instance.
v_func = VFunc()

# Smoke test on the all-zero state; the bare expression is shown as the cell output.
origin_state = np.array([0, 0, 0, 0])
v_func.compute_v(origin_state)

In [None]:
# Grid resolution: unit_num samples per unit length along each axis.
unit_num = 40
nx, ny = 3 * unit_num, 2 * unit_num
xs = np.linspace(-2, 1, nx)  # first state component, swept over [-2, 1]
ys = np.linspace(0, 2, ny)   # second state component, swept over [0, 2]

# Rows follow ys and columns follow xs, the layout contourf(xs, ys, zs) expects.
zs = np.zeros((ny, nx))

for num_x, x_pos in enumerate(xs):
    for num_y, y_pos in enumerate(ys):
        # The last two state components are held at zero for every grid point.
        raw_value = v_func.compute_v(np.array([x_pos, y_pos, 0, 0]))
        # Clamp into [-2000, 0] before storing.
        zs[num_y, num_x] = np.clip(raw_value, -2000, 0)

In [None]:
fig, axs = plt.subplots(1, figsize=(8, 8))

# Semi-transparent white overlays drawn on top of the value map.
# NOTE(review): presumably these mark obstacle / ground regions of the
# environment -- confirm against the environment definition.
rectangle_lower = plt.Rectangle((-1, 0.2), 0.5, 0.8, color='white', alpha=0.2)
rectangle_higher = plt.Rectangle((0, 1.3), 1, 0.5, color='white', alpha=0.2)
ground = plt.Rectangle((-2, 0), 3, 0.2, color='white', alpha=0.2)

# Contour edges every 100 from -2000 up to 0 (arange excludes the stop value).
levels = np.arange(-2000, 100, 100)
# `linewidths` only applies to line contours; contourf ignores it and warns,
# so it is not passed here.
contour_set = axs.contourf(xs, ys, zs, levels=levels)

# Tie the colorbar to the contour set explicitly instead of relying on
# pyplot's "current image" state.
fig.colorbar(contour_set, ax=axs)
for patch in (rectangle_lower, rectangle_higher, ground):
    axs.add_patch(patch)
axs.tick_params(axis="both", which="major", labelsize=15)
fig.savefig("v_func.pdf", dpi=300)
plt.show()