Import packages

In [1]:
import sys, os
import pathlib
# Working directory of the notebook kernel, as a pathlib.Path.
__location__ = pathlib.Path(os.getcwd())

# Directory that contains the local `snake` package. The default is the
# original machine-specific location; set the SNAKE_DOMAINS_DIR environment
# variable to point somewhere else without editing the notebook.
_snake_domains_dir = os.environ.get(
    'SNAKE_DOMAINS_DIR',
    'C:\\Users\\doore\\project\\snake_RL\\dmc\\domains',
)
# Re-running this cell must not keep adding duplicate entries to sys.path.
if _snake_domains_dir not in sys.path:
    sys.path.append(_snake_domains_dir)
import snake
from snake.envs.SnakeEnv import SnakeEnv # for ray env register
import gymnasium as gym
import ray

# Import RL algorithms
from ray.rllib.algorithms.ppo import PPOConfig 
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.policy.policy import Policy

from ray.tune.registry import register_env
from ray.tune.logger import pretty_print

import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


Make Snake Env

In [2]:
# Instantiate the registered snake environment; render_mode="human" is passed
# through to the environment so rollouts are displayed while stepping.
env = gym.make(
    'snake/SnakeEnv-v1',
    render_mode="human",
)

Load pre-trained policy

In [7]:
from ray.rllib.models import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch

# policy = Policy.from_checkpoint('C:/Users/doore/ray_results/PPO_snake_2023-03-09_21-45-50djfsrxig/checkpoint_003426')


# Resolve the torch and torch.nn modules through RLlib's import helper.
torch, nn = try_import_torch()


class RNNModel(TorchRNN, nn.Module):
    """Recurrent model: fc layer -> LSTM -> separate action and value heads."""

    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        fc_size=64,
        lstm_state_size=256,
    ):
        """Build the layers.

        Args:
            obs_space: Observation space; its preprocessor size is the input
                width of the first linear layer.
            action_space: Action space, forwarded to the base class.
            num_outputs: Width of the action head output.
            model_config: Model config, forwarded to the base class.
            name: Model name, forwarded to the base class.
            fc_size: Output width of the first linear layer.
            lstm_state_size: Size of the LSTM hidden and cell states.
        """
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        preprocessor_cls = get_preprocessor(obs_space)
        self.obs_size = preprocessor_cls(obs_space).size
        self.fc_size = fc_size
        self.lstm_state_size = lstm_state_size

        # fc -> LSTM -> two linear heads (action outputs and a scalar value).
        # NOTE(review): attribute names and creation order are left untouched;
        # a saved checkpoint presumably keys its weights on them - confirm
        # before renaming.
        self.fc1 = nn.Linear(self.obs_size, self.fc_size)
        self.lstm = nn.LSTM(self.fc_size, self.lstm_state_size, batch_first=True)
        self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
        self.value_branch = nn.Linear(self.lstm_state_size, 1)
        # Most recent LSTM output; written by forward_rnn, read by value_function.
        self._features = None

    @override(ModelV2)
    def get_initial_state(self):
        """Return two zero vectors of length `lstm_state_size` as the start state."""
        # TODO: (sven): Get rid of `get_initial_state` once Trajectory
        #  View API is supported across all of RLlib.
        # Allocate from an existing weight tensor so the states share its
        # dtype and device.
        reference = self.fc1.weight
        return [
            reference.new(1, self.lstm_state_size).zero_().squeeze(0)
            for _ in range(2)
        ]

    @override(ModelV2)
    def value_function(self):
        """Return the value head applied to the cached LSTM output, flattened to 1-D."""
        assert self._features is not None, "must call forward() first"
        values = self.value_branch(self._features)
        return torch.reshape(values, [-1])

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        """Run `inputs` through fc1 + ReLU and then the LSTM.

        The LSTM output is cached in `self._features` so that
        `value_function` can be evaluated afterwards. `seq_lens` is part of
        the signature but is not used in this implementation.

        Args:
            inputs: Input tensor fed to the first linear layer.
            state: Two-item sequence of tensors; each gets a leading
                dimension added before being handed to the LSTM as its
                hidden state (item 0) and cell state (item 1).
            seq_lens: Unused.

        Returns:
            The action head applied to the LSTM output, and the new state as
            a two-item list (hidden state, cell state) with the leading
            dimension squeezed away again.
        """
        embedded = nn.functional.relu(self.fc1(inputs))
        hidden_in = torch.unsqueeze(state[0], 0)
        cell_in = torch.unsqueeze(state[1], 0)
        lstm_out, [hidden_out, cell_out] = self.lstm(embedded, [hidden_in, cell_in])
        self._features = lstm_out
        action_out = self.action_branch(lstm_out)
        next_state = [torch.squeeze(hidden_out, 0), torch.squeeze(cell_out, 0)]
        return action_out, next_state

# Register the model class under the name "MyRNN" before restoring.
# NOTE(review): presumably the checkpoint's model config refers to this
# custom-model name - confirm against the training script.
ModelCatalog.register_custom_model("MyRNN", RNNModel)

# Restore the saved policy object(s) from the checkpoint directory.
rnn_policy = Policy.from_checkpoint(
    'C:\\Users\\doore\\ray_results\\PPO_snake_2023-03-11_00-34-28wx7bcj33\\checkpoint_000136'
)

Run simulation

In [7]:
# Roll out the restored recurrent policy for 10 rendered episodes of at most
# 400 steps each. The policy comes from `rnn_policy`; no variable named
# `policy` is defined in this notebook.
snake_policy = rnn_policy['default_policy']

try:
    for epi in range(10):
        _obs, _ = env.reset()
        # The model is recurrent: begin every episode from its initial state
        # and feed the returned state back in on the following step.
        _state = snake_policy.get_initial_state()
        for i in range(400):
            # compute_single_action returns (action, state_out, extra_info).
            _act, _state, _ = snake_policy.compute_single_action(_obs, state=_state)
            _obs, _reward, _terminated, _truncated, _ = env.step(_act)
            # An ended episode must be reset before stepping again.
            if _terminated or _truncated:
                break
finally:
    # Close the environment (and its render window) even if a step fails.
    env.close()

Check Model parameters

In [8]:
from ray.rllib.models.modelv2 import ModelV2

# Print the module itself to show its layers. `model.get_parameter` is a
# method; printing it without calling it only shows a bound-method wrapper
# around this same module repr.
print(rnn_policy['default_policy'].model)

<bound method Module.get_parameter of RNNModel(
  (fc1): Linear(in_features=32, out_features=64, bias=True)
  (lstm): LSTM(64, 256, batch_first=True)
  (action_branch): Linear(in_features=256, out_features=28, bias=True)
  (value_branch): Linear(in_features=256, out_features=1, bias=True)
)>
