# Test Door Key Online Training with d3rlpy and DQN

We will use the Door Key 8x8 environment (`MiniGrid-DoorKey-8x8-v0`) from Minigrid to train a DQN agent online with d3rlpy and record a video of the resulting policy.

In [None]:
# Test if we are running on CoLab or not
if 'google.colab' in str(get_ipython()):
  print('Running on CoLab')
  !apt-get install -y xvfb ffmpeg > /dev/null 2>&1
  %pip install pyvirtualdisplay pygame moviepy > /dev/null 2>&1
  %pip install d3rlpy
else:
  print('Not running on CoLab')

In [None]:
# Shell out to nvidia-smi to see which GPU (if any) this runtime exposes;
# the DQN setup further down targets "cuda:0".
!nvidia-smi

In [None]:
# Directory creation
import os

# Output locations used by this notebook: the recorded videos go to the second
# one. exist_ok=True makes the call idempotent, so re-running the cell is safe
# and there is no window between an existence check and the creation.
for path in ("./models", "./videos/video-doorkey-d3rlpy"):
    os.makedirs(path, exist_ok=True)

In [None]:
import d3rlpy
import gymnasium as gym
import torch
import torch.nn as nn
from d3rlpy.models.encoders import EncoderFactory

env_key = "MiniGrid-DoorKey-8x8-v0"


def create_env(env_key, max_episode_steps=1000, render_mode=None):
    """Build the environment identified by ``env_key``.

    Args:
        env_key: Environment id handed to ``gym.make``.
        max_episode_steps: Step limit forwarded to ``gym.make``; it is also
            printed as a trace of which limit was requested.
        render_mode: Render mode forwarded to ``gym.make``.

    Returns:
        When ``render_mode`` is ``None``, the environment wrapped so that only
        the ``'image'`` and ``'direction'`` observation entries are kept and
        then flattened. Otherwise the environment exactly as ``gym.make``
        returned it.
    """
    base_env = gym.make(
        env_key,
        max_episode_steps=max_episode_steps,
        render_mode=render_mode,
    )
    print(max_episode_steps)

    # A rendering environment is handed back without observation wrappers.
    if render_mode is not None:
        return base_env

    filtered_env = gym.wrappers.FilterObservation(
        base_env, filter_keys=['image', 'direction']
    )
    return gym.wrappers.FlattenObservation(filtered_env)


# Training environment (1000-step limit) and a shorter one for evaluation.
env = create_env(env_key, max_episode_steps=1000)
eval_env = create_env(env_key, max_episode_steps=100)
print(env.observation_space)

class CustomEncoder(nn.Module):
    """MLP encoder: input -> 128 -> 64 -> 32, with dropout on the first two layers.

    Args:
        observation_shape: Shape of a single observation. Only its first
            entry is used, as the width of the input layer.

    Attributes:
        feature_size: Width of the vector returned by ``forward``.
    """

    def __init__(self, observation_shape):
        super().__init__()
        print(observation_shape)
        self.fc1 = nn.Linear(observation_shape[0], 128)
        self.fc1dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 64)
        self.fc2dropout = nn.Dropout(0.5)
        self.fc3 = nn.Linear(64, 32)
        # Derived from the last layer so it cannot drift from the real output
        # width (it used to be hard-coded to 16 while fc3 emits 32 features).
        self.feature_size = self.fc3.out_features

    def forward(self, x):
        """Encode a batch of flat observations into ``feature_size`` features."""
        # Each hidden layer uses its own dropout module; fc1dropout was
        # previously defined but never applied.
        h = torch.relu(self.fc1dropout(self.fc1(x)))
        h = torch.relu(self.fc2dropout(self.fc2(h)))
        return torch.relu(self.fc3(h))
    
class CustomEncoderFactory(EncoderFactory):
    """Encoder factory that builds a ``CustomEncoder``."""

    @staticmethod
    def get_type() -> str:
        """Return the identifier string of this factory."""
        return "custom"

    def create(self, observation_shape):
        """Instantiate a ``CustomEncoder`` for ``observation_shape``."""
        encoder = CustomEncoder(observation_shape)
        return encoder

# DQN agent built on the custom encoder. The first CUDA device is used when
# one is available; otherwise the model is created on the CPU so the notebook
# still runs on machines without a GPU.
dqn = d3rlpy.algos.DQNConfig(
    encoder_factory=CustomEncoderFactory(),
    batch_size=100,
    gamma=0.9,
    target_update_interval=1000,
    learning_rate=2.5e-4
).create(device="cuda:0" if torch.cuda.is_available() else "cpu:0")


In [None]:
import numpy as np

from collections import deque
from typing import Deque, List, Sequence, Tuple

from typing_extensions import Protocol

from d3rlpy.dataset.components import EpisodeBase

from d3rlpy.dataset.buffers import BufferProtocol

from d3rlpy.dataset.writers import ExperienceWriter, _ActiveEpisode, WriterPreprocessProtocol
from d3rlpy.dataset.components import Signature

class PriorityBuffer:
    r"""Bounded transition buffer that evicts the lowest-priority entry first.

    Every stored transition carries the priority its episode had when the
    transition was appended (see ``get_priority``). Entries are kept in
    ascending priority order, so index 0 is always the next one to be evicted.

    Args:
        limit (int): buffer capacity.
    """
    # EpisodeBase is written as a forward reference: it is only needed by type
    # checkers, never at runtime.
    # Each entry is (priority, (episode, transition index)).
    _transitions: List[Tuple[float, Tuple["EpisodeBase", int]]]
    _episodes: List["EpisodeBase"]
    _limit: int

    def __init__(self, limit: int):
        self._limit = limit
        self._transitions = []
        self._episodes = []
        # id(episode) -> number of its transitions currently stored. The ids
        # stay valid because ``_episodes`` keeps those episodes alive.
        self._live_counts = {}

    def get_priority(self, episode: "EpisodeBase") -> float:
        """Return the priority of ``episode``: the mean of its rewards."""
        return episode.rewards.mean()

    def append(self, episode: "EpisodeBase", index: int) -> None:
        """Store transition ``index`` of ``episode``, evicting if over capacity."""
        priority = self.get_priority(episode)
        self._transitions.append((priority, (episode, index)))
        # Sort on the priority only; the sort is stable, so equal priorities
        # keep their insertion order and episodes are never compared.
        self._transitions.sort(key=lambda entry: entry[0])

        key = id(episode)
        if key not in self._live_counts:
            self._episodes.append(episode)
            self._live_counts[key] = 0
        self._live_counts[key] += 1

        if len(self._transitions) > self._limit:
            self._remove_lowest_priority()

    def _remove_lowest_priority(self) -> None:
        """Evict the lowest-priority transition and forget emptied episodes."""
        _, (episode, _) = self._transitions.pop(0)
        key = id(episode)
        remaining = self._live_counts[key] - 1
        if remaining > 0:
            # The episode still owns stored transitions, so it must stay listed.
            self._live_counts[key] = remaining
            return

        del self._live_counts[key]
        # Remove by identity: the evicted episode can sit anywhere in the list.
        for position, candidate in enumerate(self._episodes):
            if candidate is episode:
                del self._episodes[position]
                break

    @property
    def episodes(self) -> Sequence["EpisodeBase"]:
        """Episodes that currently own at least one stored transition."""
        return self._episodes

    @property
    def transition_count(self) -> int:
        """Number of stored transitions."""
        return len(self._transitions)

    def __len__(self) -> int:
        return len(self._transitions)

    def __getitem__(self, index: int) -> Tuple["EpisodeBase", int]:
        """Return ``(episode, transition index)`` at ``index`` in priority order."""
        _, (episode, idx) = self._transitions[index]
        return episode, idx

class CustomReplayBuffer(d3rlpy.dataset.ReplayBuffer):
    r"""Replay buffer that drops unrewarded, non-terminated episodes on clip."""

    def clip_episode(self, terminated: bool) -> None:
        r"""Clips the current episode.

        When the episode did not terminate and its mean reward is zero or
        negative, every transition of that episode is removed from the
        wrapped buffer after clipping.

        Args:
            terminated: Flag to represent environment termination.
        """
        active_episode = self._writer._active_episode

        discard = False
        if not terminated:
            rewards = active_episode.rewards
            # An episode without any reward has no defined mean; keep it
            # instead of computing NaN.
            discard = len(rewards) > 0 and bool(rewards.mean() <= 0)

        self._writer.clip_episode(terminated)

        if not discard:
            return

        # NOTE(review): assumes entries are (episode, index) pairs, as in the
        # FIFOBuffer this file wraps; other buffer layouts need their own filter.
        transitions = self._buffer._transitions
        kept = [entry for entry in transitions if entry[0] is not active_episode]
        # Refill the existing container rather than rebinding the attribute to
        # a plain list, so the buffer keeps whatever container type (and
        # capacity handling) it was created with.
        transitions.clear()
        transitions.extend(kept)

        # Identity-based removal: list.remove() compares with ==, which is not
        # what is wanted here, and raises when the episode is absent.
        episodes = self._buffer.episodes
        for position, episode in enumerate(episodes):
            if episode is active_episode:
                del episodes[position]
                break


class CustomWriterPreprocess(d3rlpy.dataset.WriterPreprocessProtocol):
    """Pass-through writer preprocessor that prints rewards above 0.1."""

    def process_observation(self, observation: d3rlpy.types.Observation) -> d3rlpy.types.Observation:
        """Return the observation unchanged."""
        return observation

    def process_action(self, action: np.ndarray) -> np.ndarray:
        """Return the action unchanged."""
        return action

    def process_reward(self, reward: np.ndarray) -> np.ndarray:
        """Return the reward unchanged, printing it first when it exceeds 0.1."""
        is_notable = reward > 0.1
        if is_notable:
            print(reward)
        return reward
    
# Preprocessor handed to the replay buffer below.
writer_preprocessor = CustomWriterPreprocess()

#buffer = PriorityBuffer(200)
# FIFO storage created with a limit of 10000, wrapped by the custom replay
# buffer defined in this file.
buffer = d3rlpy.dataset.FIFOBuffer(10000)
buffer = CustomReplayBuffer(
    buffer,
    env=env, 
    #observation_signature=observation_signature,
    writer_preprocessor=writer_preprocessor
)

#buffer = d3rlpy.dataset.create_fifo_replay_buffer(
#    limit=10000, env=env)

# presumably epsilon decays linearly from 0.9 to 0.3 — confirm argument order
explorer = d3rlpy.algos.LinearDecayEpsilonGreedy(0.9, 0.3)
dqn.fit_online(
    env,
    buffer,
    explorer,
    n_steps=1000000,  # train for 1M steps
    eval_env=eval_env,
    n_steps_per_epoch=100000,  # one epoch, and so one evaluation, every 100K steps
    update_start_step=50000,  # parameter update starts after 50K steps
    update_interval=10  # presumably one update every 10 steps — confirm against d3rlpy docs
)

In [None]:
import gymnasium as gym

import numpy as np
# RecordVideo comes from gymnasium, matching the gymnasium environments used
# everywhere else in this file (it used to be imported from legacy `gym`).
from gymnasium.wrappers import RecordVideo

# start virtual display
d3rlpy.notebook_utils.start_virtual_display()

# Two environments are driven in lockstep: `env` yields the flattened
# observations the agent consumes, `env_video` renders frames for the video.
env = create_env(env_key, max_episode_steps=1000)
# wrap RecordVideo wrapper
env_video = RecordVideo(gym.make(env_key, render_mode="rgb_array"), './videos/video-doorkey-d3rlpy')

# The same seed is given to both environments so they start from one layout.
seed = 1

# interaction: reset() returns (observation, info); the info dict is unused.
observation, _ = env.reset(seed=seed)
env_video.reset(seed=seed)

explorer = d3rlpy.algos.ConstantEpsilonGreedy(0.3)

done = truncated = False
while not (done or truncated):
    # The agent expects a batch dimension in front of the observation.
    x = np.expand_dims(observation, axis=0)
    action = explorer.sample(dqn, x, 0)[0]

    # Apply the same action to both environments to keep them in sync.
    observation, reward, done, truncated, _ = env.step(action)
    env_video.step(action)

if done:
    print("reward:", reward)
    print("DONE!!!")
else:
    print("Truncated")

# Closing the recorder finalizes the video on both exit paths; previously it
# was only flushed (through an extra reset) when the episode terminated.
env_video.close()
env.close()

d3rlpy.notebook_utils.render_video("./videos/video-doorkey-d3rlpy/rl-video-episode-0.mp4")