## Import libraries

In [1]:
import torch
import torchrl
from tensordict.nn import InteractionType
from torch.distributions.normal import Normal
import torchrl.envs
from torch.distributions.categorical import Categorical
import numpy as np
import gymnasium as gym

## Inference

In [None]:
# Replay the saved CartPole policy in a rendered TorchRL environment.
# GymEnv is reached through the already-imported torchrl.envs module; the bare
# name GymEnv is never imported in this notebook and would raise NameError.
env = torchrl.envs.GymEnv("CartPole-v1", render_mode="human")

try:
    # Forward slashes work on both Windows and POSIX; the previous
    # "models\p..." literal relied on an invalid escape sequence and a
    # Windows-only separator.
    # NOTE(review): torch.load unpickles a full module object here -- only load
    # checkpoints you trust. Recent PyTorch releases may require
    # weights_only=False for this kind of file; confirm against the installed
    # version.
    model = torch.load("models/pole168_steps_final.pth")
    model.cpu()
    model.eval()

    with torch.inference_mode(), torchrl.envs.utils.set_exploration_type(InteractionType.MEAN):
        # Run up to 500 steps with the loaded policy; the rollout is not
        # interrupted when an episode reports done.
        env.rollout(500, break_when_any_done=False, policy=model)
        env.render()
finally:
    # Always release the environment, even when loading or the rollout fails.
    env.close()

In [None]:
# Play five rendered episodes of discrete LunarLander with the TorchScript
# policy and print each episode's total reward.
env = gym.make("LunarLander-v2", render_mode="human", wind_power=15, turbulence_power=1.5)

try:
    # Forward slashes work on both Windows and POSIX; the previous
    # "models\l..." literal relied on an invalid escape sequence and a
    # Windows-only separator.
    model = torch.jit.load("models/lander_final_sv.pth")
    model.cpu()
    model.eval()

    # NOTE(review): the exploration-type context only influences TorchRL
    # modules; the action below is still drawn with an explicit sample().
    with torch.inference_mode(), torchrl.envs.utils.set_exploration_type(InteractionType.MEAN):
        for _ in range(5):
            observation, _info = env.reset()
            # Add a leading batch dimension of size 1 for the model.
            state = torch.tensor(np.array([observation]), dtype=torch.float32)
            done = False
            score = 0
            env.render()
            while not done:
                # Categorical takes `probs` as its first positional argument,
                # so the model output is presumably action probabilities --
                # TODO confirm it is not logits.
                probs = model(state)
                dist = Categorical(probs)
                action = dist.sample()
                observation, r, terminated, truncated, _info = env.step(action.item())
                score += r
                done = terminated or truncated
                state = torch.tensor(np.array([observation]), dtype=torch.float32)
            print(score)
finally:
    # Always release the environment, even if loading or an episode fails.
    env.close()

In [None]:
# Play five rendered episodes of continuous LunarLander with the TorchScript
# policy and print each episode's total reward.
env = gym.make("LunarLander-v2", render_mode="human", wind_power=15, turbulence_power=1.5, continuous=True)

try:
    # Forward slashes work on both Windows and POSIX; the previous
    # "models\l..." literal relied on an invalid escape sequence and a
    # Windows-only separator.
    # NOTE(review): this is the same checkpoint file the discrete cell loads --
    # confirm it is the intended continuous-action model.
    model = torch.jit.load("models/lander_final_sv.pth")
    model.cpu()
    model.eval()

    # NOTE(review): the exploration-type context only influences TorchRL
    # modules; the action below is still drawn with an explicit sample().
    with torch.inference_mode(), torchrl.envs.utils.set_exploration_type(InteractionType.MEAN):
        for _ in range(5):
            observation, _info = env.reset()
            # Add a leading batch dimension of size 1 for the model.
            state = torch.tensor(np.array([observation]), dtype=torch.float32)
            done = False
            score = 0
            env.render()
            while not done:
                res = model(state)
                # The model output is split in half along the last axis into
                # the Normal's location and its raw (pre-softplus) scale.
                loc, scale = res.chunk(2, -1)
                # softplus keeps the scale positive; the clamp bounds it away
                # from zero.
                scale = torch.nn.functional.softplus(scale).clamp_min(1e-4)
                dist = Normal(loc, scale)
                action = dist.sample()
                # The continuous action is a vector, so Tensor.item() (valid
                # only for single-element tensors) cannot be used. Drop the
                # batch dimension and pass the array to the environment.
                env_action = action.squeeze(0).cpu().numpy()
                observation, r, terminated, truncated, _info = env.step(env_action)
                score += r
                done = terminated or truncated
                state = torch.tensor(np.array([observation]), dtype=torch.float32)
            print(score)
finally:
    # Always release the environment, even if loading or an episode fails.
    env.close()