In [12]:
import warnings
import gymnasium as gym
import torch
import numpy as np
from tianshou.data import Collector
from tianshou.env import SubprocVectorEnv
from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor, Critic
from tqdm import tqdm
import tianshou

# Hide DeprecationWarning messages emitted after this point.
warnings.filterwarnings("ignore", category=DeprecationWarning)


# Device string handed to every network below and to the policy itself.
DEVICE = "cuda:1"
# Define the environment
env = gym.make("LunarLander-v2")
# Sizes for the networks are read from the environment's spaces.
state_shape = env.observation_space.shape
action_shape = env.action_space.n
# The same `net` instance is passed to both Actor and Critic, so the two
# heads are built on one shared module with two hidden sizes of 1024.
net = Net(state_shape, hidden_sizes=[1024] * 2, device=DEVICE)
actor = Actor(net, action_shape, device=DEVICE)
critic = Critic(net, device=DEVICE)

# Distribution class the policy uses to turn the actor output into actions.
dist_fn = torch.distributions.Categorical


# No optimizer is supplied (optim=None); nothing in the visible code trains
# the policy, it is only run forward.
policy = PPOPolicy(
    actor=actor,
    critic=critic,
    dist_fn=dist_fn,
    optim=None,
    discount_factor=0.99,
    max_grad_norm=0.5,
    eps_clip=0.2,
    vf_coef=0.5,
    ent_coef=0.01,
    reward_normalization=True,
    action_space=env.action_space,
    action_scaling=False,
    deterministic_eval=True,
).to(DEVICE)
# NOTE(review): presumably eval mode is what makes deterministic_eval=True
# take effect -- confirm against the tianshou documentation.
policy.eval()
# Tensor holding 0 .. n-1; not referenced anywhere else in the visible code.
possible_actions = torch.arange(env.action_space.n)
def get_loss(env, policy, obs):
    """Return the spread (max minus min) of the policy output for one observation.

    The observation is converted to a float32 tensor on ``DEVICE``, given a
    leading batch axis of size one, and fed through ``policy`` with gradient
    tracking disabled. The ``logits`` entry of the result is moved back to
    NumPy and the difference between its largest and smallest element is
    returned.

    Args:
        env: Environment whose ``action_space`` is checked to be discrete.
        policy: Policy to evaluate; only ``PPOPolicy`` is handled.
        obs: A single (unbatched) observation accepted by ``torch.tensor``.

    Returns:
        ``max - min`` over the policy's ``logits`` output, as a NumPy scalar.

    Raises:
        NotImplementedError: If ``policy`` is not a ``PPOPolicy``.
        AssertionError: If ``env.action_space`` is not ``gym.spaces.Discrete``.
    """
    # Guard clause: anything other than a PPO policy is rejected up front.
    if not isinstance(policy, PPOPolicy):
        raise NotImplementedError(f"Model type {type(policy)} not supported")

    assert isinstance(
        env.action_space, gym.spaces.Discrete
    ), "Only discrete action spaces supported for loss function"

    obs_tensor = torch.tensor(obs, dtype=torch.float32).to(DEVICE)
    # The policy is called on a batch, so wrap the single observation in a
    # batch of size one.
    single_batch = tianshou.data.Batch(obs=obs_tensor.unsqueeze(0), info="")
    with torch.no_grad():
        result = policy(single_batch)

    # Drop the batch axis again and bring the values back to host memory.
    scores = result.logits.squeeze(0).detach().cpu().numpy()
    highest = scores.max()
    lowest = scores.min()
    return highest - lowest


# Smoke test: reset the environment, run the policy once on the first
# observation, and print both the raw forward output and the get_loss value.
obs, _ = env.reset()
batch = tianshou.data.Batch(obs=obs.reshape(1, -1), info="")
forward_result = policy.forward(batch)
print(forward_result)
spread = get_loss(env, policy, obs)
print(spread)

Batch(
    logits: tensor([[0.2432, 0.2141, 0.2676, 0.2751]], device='cuda:1',
                   grad_fn=<SoftmaxBackward0>),
    act: tensor([3], device='cuda:1'),
    state: None,
    dist: Categorical(probs: torch.Size([1, 4])),
)
0.060961783
